diff --git a/.gitignore b/.gitignore index 828bbe9bd3363853ae3f58f54a8d5f60cefad837..b5306b8b79c37166e5496cf17a3e39b86b9a6314 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ __pycache__ cmake_build/ .idea/** /build/ +[Bb]uild/ /tensorflow/core/util/version_info.cc /tensorflow/python/framework/fast_tensor_util.cpp Pods diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8669c25c452b53da48239bc20c9a2d3528e75422..db4b1581ae671b1e676e215c9a80dfaab832fa21 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -90,7 +90,7 @@ Bazel BUILD files also need to include a license section, e.g., Changes to TensorFlow C++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). -Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do: +Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do: ```bash apt-get install -y clang-tidy diff --git a/README.md b/README.md index 6fb4486d0de9ff476b5cf1dbd63d66879637df84..63853137cfd30b396f8c7d204811f3e4a1794c07 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ $ python 42 >>> sess.close() ``` +Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). ## Contribution guidelines diff --git a/RELEASE.md b/RELEASE.md index 27f73b7fc6a5240909a00056f910cc2ad304b759..e09e9c6190f57adec67c2ae1d85848dabfd9c2a7 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,62 @@ +# Release 1.9.0 + +## Major Features And Improvements +* Update tf.keras to the Keras 2.1.6 API. +* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. +* Adding support of core feature columns and losses to gradient boosted trees estimators. +* The distributions.Bijector API supports broadcasting for Bijectors with new API changes. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/distributions/bijectors/Bijector) for more details. +* Layered variable names have changed in the following conditions: + * Using `tf.keras.layers` with custom variable scopes. + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details + +## Breaking Chances + * If you're opening empty variable scopes; replace `variable_scope`('', ...) by `variable_scope`(`tf.get_variable_scope()`, ...). + +## Bug Fixes and Other Changes +* `tf.data`: + * The `DatasetBase::DebugString()` method is now `const`. + * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets. +* Eager Execution: +* `tf.keras`: + * Move Keras code out of _impl folder and remove API files. + * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. + * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. +* Accelerated Linear Algebra (XLA): +* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). +* `tf.contrib`: + * Add `tf.contrib.data.choose_from_datasets()`. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`. + * `tf.contrib.framework.zero_initializer` supports ResourceVariable. + * Adding "constrained_optimization" to tensorflow/contrib. +* Other: + * Add GCS Configuration Ops. + * Changing signature of `MakeIterator` to enable propagating error status. + * KL divergence for two Dirichlet distributions. + * More consistent GcsFileSystem behavior for certain reads past EOF. + * Update benchmark for tf.scan to match ranges across eager and graph modes. + * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. + * Add optional `args` argument to `Dataset.from_generator()`. + * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). + * Benchmark for tf.scan in graph and eager modes. + * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. + * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch. + * Support indicator column in boosted trees. + * Prevent `tf.gradients()` from backpropagating through integer tensors. + * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. + * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. + * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. + * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. + * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. + * Allow LinearOperator to broadcast. + * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. + + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Abdullah Alrasheed, Achal Shah, Ad-530, ADiegoCAlonso, Aditya Yogi, Ag Ramesh, akindyakov, Andy Kernahan, Anya Petrova, Aurelien Geron, Ben, Ben Barsdell, Bhavani-Subramanian, braincodercn, Brett Koonce, Brian Nemsick, Brian Zier, Bryan Heden, candy.dc, cclauss, Clayne Robison, ctiijima, Dalmo Cirne, David Norman, David T.H. Kao, DosLin, ekelsen, Elson Rodriguez, Erik Smistad, Felix Abecassis, Fergal Cotter, fo40225, foo0x29a, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, gdh1995, Geoffrey Irving, Giuseppe, gracehoney, Guido Zuidhof, Guillaume Klein, Guozhong Zhuang, Haggai, Harald Husum, imsheridan, Ivan Zhang, Jan Zikes, Jayaram Bobba, Jesse Benson, Jesse Gumz, Jiajia Li, Jie, jinghuangintel, Jingwen, jjsjann123, Joe Yearsley, Joel Hestness, Joel Shor, josephyearsley, Junpeng Lao, Karol M. Langner, Kb Sriram, krantideep95, Krish Ravindranath, Letian Feng, Loo Rong Jie, Lukas Geiger, Maciej, Mahmoud Abuzaina, ManHyuk, Mark Ryan, mbhuiyan, Michal Turek, Mostafa Alaa, Myungsung Kwak, Nand Dalal, Nehal J Wani, Neil Tenenholtz, ngc92, Nicholas Nadeau, P.Eng., Avs, Niranjan Hasabnis, P-Hidringer, Paul Van Eck, Peng Yu, Qing Zhao, Qingying Chen, Quanlong, Rajendra Arora, Rholais Lii, rmanyari, Robin Richtsfeld, Russell Klopfer, Sagi, Sam Sendelbach, Sandeep N Gupta, Sandip Giri, Sarah Edkins, Scott Tseng, Sdalbsoo, Sergii Khomenko, Seungwoo Choi (Biggie), Seyed Majid Azimi, Shaoning Zeng, shengfuintel, Siu Kei, Muk, Smit Shilu, soonson, Stefan Schweter, Sukhwan Kim, Sunitha Kambhampati, Taehoon Lee, tamimaddari82, Tang, Wenyi, Ted Chang, u2takey, Utkarsh Upadhyay, Vadim Markovtsev, voegtlel, Wai Hon Law, wangsiyu, Wenhao Hu, wenhao.hu, William D. Irons, Yan Facai (颜发才), Yanbo Liang, Yihong Wang, Yilei (Dolee) Yang, Yong Tang, Yuan (Terry) Tang + # Release 1.8.0 ## Major Features And Improvements diff --git a/SECURITY.md b/SECURITY.md index 0a4be37cbc20665bf8be68616496d35c8b6d7fb7..0b52fdc7ab84b7bd5bce5d247ede81b40699005c 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -242,12 +242,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= -----END PGP PUBLIC KEY BLOCK----- ``` -### Known vulnerabilities - -| Type | Versions affected | Reported by | Additional Information | -|--------------------|:-----------------:|-----------------------|-----------------------------| -| TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-003.md) | -| GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-002.md) | -| BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-001.md) | -| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +### Known Vulnerabilities +For a list of known vulnerabilities and security advisories for TensorFlow, +[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md). diff --git a/WORKSPACE b/WORKSPACE index 44baf78f49c6f97a95b67fe9e27d2cd978c2a32a..fd7570a80ae2ee0087f7d2fd771fcce5b9690028 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -22,26 +22,10 @@ check_bazel_version_at_least("0.10.0") load("//tensorflow:workspace.bzl", "tf_workspace") -# Uncomment and update the paths in these entries to build the Android demo. -#android_sdk_repository( -# name = "androidsdk", -# api_level = 23, -# # Ensure that you have the build_tools_version below installed in the -# # SDK manager as it updates periodically. -# build_tools_version = "26.0.1", -# # Replace with path to Android SDK on your system -# path = "", -#) -# -#android_ndk_repository( -# name="androidndk", -# path="", -# # This needs to be 14 or higher to compile TensorFlow. -# # Please specify API level >= 21 to build for 64-bit architecture -# # otherwise the Android NDK will automatically select the latest -# # API level it does support without notice. -# # Note that the NDK version is not the API level. -# api_level=14) +load("//third_party/android:android_configure.bzl", "android_configure") +android_configure(name="local_config_android") +load("@local_config_android//:android.bzl", "android_workspace") +android_workspace() # Please add all new TensorFlow dependencies in workspace.bzl. tf_workspace() diff --git a/configure.py b/configure.py index 96caa2e2dd6f772ebdf8934708d133cdd16514bb..ad585fa52e571d62d11864531476e46b2f15f297 100644 --- a/configure.py +++ b/configure.py @@ -670,8 +670,9 @@ def create_android_ndk_rule(environ_cp): error_msg=('The path %s or its child file "source.properties" ' 'does not exist.') ) - - write_android_ndk_workspace_rule(android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL', + check_ndk_level(android_ndk_home_path)) def create_android_sdk_rule(environ_cp): @@ -733,41 +734,12 @@ def create_android_sdk_rule(environ_cp): error_msg=('The selected SDK does not have build-tools version %s ' 'available.')) - write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level) - - -def write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level): - print('Writing android_sdk_workspace rule.\n') - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_sdk_repository( - name="androidsdk", - api_level=%s, - path="%s", - build_tools_version="%s")\n -""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) - - -def write_android_ndk_workspace_rule(android_ndk_home_path): - print('Writing android_ndk_workspace rule.') - ndk_api_level = check_ndk_level(android_ndk_home_path) - if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: - print('WARNING: The API level of the NDK in %s is %s, which is not ' - 'supported by Bazel (officially supported versions: %s). Please use ' - 'another version. Compiling Android targets may result in confusing ' - 'errors.\n' % (android_ndk_home_path, ndk_api_level, - _SUPPORTED_ANDROID_NDK_VERSIONS)) - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_ndk_repository( - name="androidndk", - path="%s", - api_level=%s)\n -""" % (android_ndk_home_path, ndk_api_level)) + write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION', + android_build_tools_version) + write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', + android_api_level) + write_action_env_to_bazelrc('ANDROID_SDK_HOME', + android_sdk_home_path) def check_ndk_level(android_ndk_home_path): @@ -780,18 +752,16 @@ def check_ndk_level(android_ndk_home_path): revision = re.search(r'Pkg.Revision = (\d+)', filedata) if revision: - return revision.group(1) - return None - - -def workspace_has_any_android_rule(): - """Check the WORKSPACE for existing android_*_repository rules.""" - with open(_TF_WORKSPACE, 'r') as f: - workspace = f.read() - has_any_rule = re.search(r'^android_[ns]dk_repository', - workspace, - re.MULTILINE) - return has_any_rule + ndk_api_level = revision.group(1) + else: + raise Exception('Unable to parse NDK revision.') + if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_api_level, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + return ndk_api_level def set_gcc_host_compiler_path(environ_cp): @@ -973,6 +943,35 @@ def set_tf_cudnn_version(environ_cp): write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) +def is_cuda_compatible(lib, cuda_ver, cudnn_ver): + """Check compatibility between given library and cudnn/cudart libraries.""" + ldd_bin = which('ldd') or '/usr/bin/ldd' + ldd_out = run_shell([ldd_bin, lib], True) + ldd_out = ldd_out.split(os.linesep) + cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') + cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') + cudnn = None + cudart = None + cudnn_ok = True # assume no cudnn dependency by default + cuda_ok = True # assume no cuda dependency by default + for line in ldd_out: + if 'libcudnn.so' in line: + cudnn = cudnn_pattern.search(line) + cudnn_ok = False + elif 'libcudart.so' in line: + cudart = cuda_pattern.search(line) + cuda_ok = False + if cudnn and len(cudnn.group(1)): + cudnn = convert_version_to_int(cudnn.group(1)) + if cudart and len(cudart.group(1)): + cudart = convert_version_to_int(cudart.group(1)) + if cudnn is not None: + cudnn_ok = (cudnn == cudnn_ver) + if cudart is not None: + cuda_ok = (cudart == cuda_ver) + return cudnn_ok and cuda_ok + + def set_tf_tensorrt_install_path(environ_cp): """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION. @@ -989,8 +988,8 @@ def set_tf_tensorrt_install_path(environ_cp): raise ValueError('Currently TensorRT is only supported on Linux platform.') # Ask user whether to add TensorRT support. - if str(int(get_var( - environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1': + if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', + False))) != '1': return for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): @@ -1003,47 +1002,29 @@ def set_tf_tensorrt_install_path(environ_cp): # Result returned from "read" will be used unexpanded. That make "~" # unusable. Going through one more level of expansion to handle that. - trt_install_path = os.path.realpath( - os.path.expanduser(trt_install_path)) + trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path)) def find_libs(search_path): """Search for libnvinfer.so in "search_path".""" fl = set() if os.path.exists(search_path) and os.path.isdir(search_path): - fl.update([os.path.realpath(os.path.join(search_path, x)) - for x in os.listdir(search_path) if 'libnvinfer.so' in x]) + fl.update([ + os.path.realpath(os.path.join(search_path, x)) + for x in os.listdir(search_path) + if 'libnvinfer.so' in x + ]) return fl possible_files = find_libs(trt_install_path) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib'))) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64'))) - - def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver): - """Check the compatibility between tensorrt and cudnn/cudart libraries.""" - ldd_bin = which('ldd') or '/usr/bin/ldd' - ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep) - cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') - cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') - cudnn = None - cudart = None - for line in ldd_out: - if 'libcudnn.so' in line: - cudnn = cudnn_pattern.search(line) - elif 'libcudart.so' in line: - cudart = cuda_pattern.search(line) - if cudnn and len(cudnn.group(1)): - cudnn = convert_version_to_int(cudnn.group(1)) - if cudart and len(cudart.group(1)): - cudart = convert_version_to_int(cudart.group(1)) - return (cudnn == cudnn_ver) and (cudart == cuda_ver) - cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION']) cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$') highest_ver = [0, None, None] for lib_file in possible_files: - if is_compatible(lib_file, cuda_ver, cudnn_ver): + if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver): matches = nvinfer_pattern.search(lib_file) if len(matches.groups()) == 0: continue @@ -1059,12 +1040,13 @@ def set_tf_tensorrt_install_path(environ_cp): # Try another alternative from ldconfig. ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' ldconfig_output = run_shell([ldconfig_bin, '-p']) - search_result = re.search( - '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output) + search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)', + ldconfig_output) if search_result: libnvinfer_path_from_ldconfig = search_result.group(2) if os.path.exists(libnvinfer_path_from_ldconfig): - if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver): + if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver, + cudnn_ver): trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) tf_tensorrt_version = search_result.group(1) break @@ -1223,7 +1205,7 @@ def set_tf_cuda_compute_capabilities(environ_cp): # Check whether all capabilities from the input is valid all_valid = True # Remove all whitespace characters before splitting the string - # that users may insert by accident, as this will result in error + # that users may insert by accident, as this will result in error tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) @@ -1556,21 +1538,15 @@ def main(): set_build_strip_flag() set_windows_build_flags() - if workspace_has_any_android_rule(): - print('The WORKSPACE file has at least one of ["android_sdk_repository", ' - '"android_ndk_repository"] already set. Will not ask to help ' - 'configure the WORKSPACE. Please delete the existing rules to ' - 'activate the helper.\n') - else: - if get_var( - environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', - False, - ('Would you like to interactively configure ./WORKSPACE for ' - 'Android builds?'), - 'Searching for NDK and SDK installations.', - 'Not configuring the WORKSPACE for Android builds.'): - create_android_ndk_rule(environ_cp) - create_android_sdk_rule(environ_cp) + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. See tools/bazel.rc for ' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 9b07669a5d8e4da6ce202fc9196185b91d8e7e2e..a15d033013f573ca7a182cc72cb4b7a8cec0e273 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -154,6 +154,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "linux_s390x", + values = {"cpu": "s390x"}, + visibility = ["//visibility:public"], +) + config_setting( name = "debug", values = { @@ -398,6 +404,7 @@ config_setting( package_group( name = "internal", packages = [ + "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", "//tensorflow_fold/llgtm/...", @@ -424,6 +431,22 @@ filegroup( data = glob(["docs_src/**/*.md"]), ) +cc_library( + name = "grpc", + deps = select({ + ":linux_s390x": ["@grpc//:grpc_unsecure"], + "//conditions:default": ["@grpc"], + }), +) + +cc_library( + name = "grpc++", + deps = select({ + ":linux_s390x": ["@grpc//:grpc++_unsecure"], + "//conditions:default": ["@grpc//:grpc++"], + }), +) + # A shared object which includes registration mechanisms for ops and # kernels. Does not include the implementations of any ops or kernels. Instead, # the library which loads libtensorflow_framework.so @@ -451,6 +474,15 @@ filegroup( tf_cc_shared_object( name = "libtensorflow_framework.so", framework_so = [], + linkopts = select({ + "//tensorflow:darwin": [], + "//tensorflow:windows": [], + "//tensorflow:windows_msvc": [], + "//conditions:default": [ + "-Wl,--version-script", # This line must be directly followed by the version_script.lds file + "$(location //tensorflow:tf_framework_version_script.lds)", + ], + }), linkstatic = 1, visibility = ["//visibility:public"], deps = [ @@ -460,6 +492,7 @@ tf_cc_shared_object( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", "//tensorflow/core:lib_internal_impl", "//tensorflow/stream_executor:stream_executor_impl", + "//tensorflow:tf_framework_version_script.lds", ] + tf_additional_binary_deps(), ) @@ -539,15 +572,27 @@ exports_files( ) gen_api_init_files( - name = "python_api_gen", + name = "tensorflow_python_api_gen", srcs = ["api_template.__init__.py"], root_init_template = "api_template.__init__.py", ) py_library( name = "tensorflow_py", - srcs = [":python_api_gen"], + srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":tensorflow_py_no_contrib", + "//tensorflow/contrib:contrib_py", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_library( + name = "tensorflow_py_no_contrib", + srcs = [":tensorflow_python_api_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = ["//tensorflow/python"], + deps = ["//tensorflow/python:no_contrib"], ) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 9b0d7d48afd058607badc90b95c9dca0c4ceaa31..779f65d5b17c350833f67f07985b00e8eb561e72 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -20,9 +20,25 @@ from __future__ import print_function # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import + +try: + import os # pylint: disable=g-import-not-at-top + # Add `estimator` attribute to allow access to estimator APIs via + # "tf.estimator..." + from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top + + # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." + # style imports. + from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top + __path__ += [os.path.dirname(estimator_api.__file__)] + del estimator_api + del os +except (ImportError, AttributeError): + print('tf.estimator package not installed.') + # API IMPORTS PLACEHOLDER -from tensorflow.python.util.lazy_loader import LazyLoader +from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') del LazyLoader diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index b86b277ac3200b88ae03490a6c1b64d464e81950..12f0d8bff4720d98b7f45b113dc62c881e32a399 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -631,7 +631,22 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, "Failed to allocate memory to serialize message of type '", in.GetTypeName(), "' and size ", proto_size); } - in.SerializeToArray(buf, proto_size); + // SerializeToArray takes size as an int. + // This next 'if' is a workaround till we update to depend on a version + // of protocol buffers that includes + // https://github.com/google/protobuf/pull/4739 + if (proto_size > std::numeric_limits::max()) { + return InvalidArgument("Cannot serialize protocol buffer of type ", + in.GetTypeName(), " as the serialized size (", + proto_size, + "bytes) would be larger than the limit (", + std::numeric_limits::max(), " bytes)"); + } + if (!in.SerializeToArray(buf, proto_size)) { + return InvalidArgument("Unable to serialize ", in.GetTypeName(), + " protocol buffer, perhaps the serialized size (", + proto_size, " bytes) is too large?"); + } out->data = buf; out->length = proto_size; out->data_deallocator = [](void* data, size_t length) { @@ -2108,7 +2123,7 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status) { GraphDef def; - if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return nullptr; } @@ -2138,7 +2153,7 @@ void TF_GraphImportGraphDefWithReturnOutputs( return; } GraphDef def; - if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return; } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index f265da2c2c89c0e9caf14f2213c606fcb69997e0..93d07135e152d559d09018f78add9acd4b2cd2a3 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -54,7 +54,6 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", @@ -93,10 +92,10 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime/eager:eager_client", "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", ], ) @@ -139,7 +138,7 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 81221c4078bec9820ee187efdf0314da378be62b..6e4764bcbf218ae144306cf1ab04c6de4cd2d6c7 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -36,9 +36,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/node_def_util.h" @@ -46,6 +46,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -147,46 +148,66 @@ tensorflow::Status CreateRemoteContexts( tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, TFE_Context** ctx) { + // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the + // server object (which currently CHECK-fails) and we miss the error, instead, + // we log the error, and then return to allow the user to see the error + // message. +#define LOG_AND_RETURN_IF_ERROR(...) \ + do { \ + const ::tensorflow::Status _status = (__VA_ARGS__); \ + LOG(ERROR) << _status.error_message(); \ + if (TF_PREDICT_FALSE(!_status.ok())) return _status; \ + } while (0) + string worker_name = tensorflow::strings::StrCat( "/job:", opts->server_def.job_name(), "/replica:0/task:", opts->server_def.task_index()); - std::unique_ptr server; - TF_RETURN_IF_ERROR( - tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server)); - TF_RETURN_IF_ERROR(server->Start()); + std::unique_ptr server; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server)); + + tensorflow::GrpcServer* grpc_server = + dynamic_cast(server.get()); + if (grpc_server == nullptr) { + LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal( + "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); + } + + LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); std::vector remote_workers; - server->master_env()->worker_cache->ListWorkers(&remote_workers); + grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers); remote_workers.erase( std::remove(remote_workers.begin(), remote_workers.end(), worker_name), remote_workers.end()); std::unique_ptr remote_device_mgr; - TF_RETURN_IF_ERROR(GetAllRemoteDevices( - remote_workers, server->master_env()->worker_cache, &remote_device_mgr)); + LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices( + remote_workers, grpc_server->master_env()->worker_cache, + &remote_device_mgr)); std::shared_ptr channel_cache = - server->channel_cache(); + grpc_server->channel_cache(); std::unique_ptr remote_eager_workers( tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); // Initialize remote eager workers. tensorflow::gtl::FlatMap remote_contexts; - TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, - remote_eager_workers.get(), - opts->async, &remote_contexts)); + LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, + remote_eager_workers.get(), + opts->async, &remote_contexts)); tensorflow::RemoteRendezvous* r = - server->worker_env()->rendezvous_mgr->Find(0); + grpc_server->worker_env()->rendezvous_mgr->Find(0); - auto* device_mgr = server->worker_env()->device_mgr; + auto* device_mgr = grpc_server->worker_env()->device_mgr; *ctx = new TFE_Context(opts->session_options.options, opts->policy, opts->async, device_mgr, r, std::move(server), std::move(remote_eager_workers), std::move(remote_device_mgr), remote_contexts); return tensorflow::Status::OK(); +#undef LOG_AND_RETURN_IF_ERROR } } // namespace @@ -421,8 +442,11 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, return ret; } -void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { - op->operation.MutableAttrs()->Set(attr_name, value); +void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value, + size_t length) { + op->operation.MutableAttrs()->Set( + attr_name, + tensorflow::StringPiece(static_cast(value), length)); } void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { @@ -473,16 +497,22 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } -#define TFE_OP_SET_ATTR_LIST(fn, type) \ - void fn(TFE_Op* op, const char* attr_name, const type* values, \ - int num_values) { \ - op->operation.MutableAttrs()->Set( \ - attr_name, \ - tensorflow::gtl::ArraySlice(values, num_values)); \ +void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values) { + std::vector v(num_values); + for (int i = 0; i < num_values; ++i) { + v[i] = tensorflow::StringPiece(static_cast(values[i]), + lengths[i]); } -TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*) -TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) -#undef TFE_OP_SET_ATTR_LIST + op->operation.MutableAttrs()->Set(attr_name, v); +} + +void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, + const float* values, int num_values) { + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice(values, num_values)); +} void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, const int64_t* values, int num_values) { @@ -655,9 +685,11 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, const char* attr_name, TF_Status* status) { switch (default_value.value_case()) { - case tensorflow::AttrValue::kS: - TFE_OpSetAttrString(op, attr_name, default_value.s().data()); + case tensorflow::AttrValue::kS: { + const string& v = default_value.s(); + TFE_OpSetAttrString(op, attr_name, v.data(), v.size()); break; + } case tensorflow::AttrValue::kI: TFE_OpSetAttrInt(op, attr_name, static_cast(default_value.i())); break; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 1862af3ce2f505a6e83b4805417eaf335ed07bc0..fdbd5374b2afe815c3a81b453930eb8f1fa351d3 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -278,7 +278,8 @@ TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, - const char* value); + const void* value, + size_t length); TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value); TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, @@ -305,7 +306,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, - const char** value, + const void* const* values, + const size_t* lengths, int num_values); TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 04a6efc47c5177c82b7e88168b67cc584587de7c..4c5077023d5bb3b83808bf3908e7110dd026e3ad 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -39,7 +39,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/remote_device.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" @@ -78,7 +78,7 @@ struct TFE_Context { TFE_ContextDevicePlacementPolicy default_policy, bool async, tensorflow::DeviceMgr* local_device_mgr, tensorflow::Rendezvous* rendezvous, - std::unique_ptr server, + std::unique_ptr server, std::unique_ptr remote_eager_workers, std::unique_ptr remote_device_mgr, const tensorflow::gtl::FlatMap& diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 27ff5f7211b0592637a173d337f93c10d376443f..cd035940ff8f15a084a49953bbbc2419a93e3541 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "tensorflow/c/eager/c_api_test_util.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -132,18 +132,20 @@ void TestRemoteExecute(bool async) { server_def.set_task_index(1); - std::unique_ptr worker_server; - ASSERT_TRUE( - tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) - .ok()); + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); ASSERT_TRUE(worker_server->Start().ok()); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), status); - TFE_ContextOptionsSetAsync(opts, static_cast(1)); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -205,6 +207,83 @@ void TestRemoteExecute(bool async) { TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } +void TestRemoteExecuteSilentCopies(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE(tensorflow::GrpcServer::Create( + server_def, tensorflow::Env::Default(), &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + + // Handles are on task0, but op is on remote (task1). + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); } +TEST(CAPI, RemoteExecuteSilentCopiesAsync) { + TestRemoteExecuteSilentCopies(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); @@ -1083,8 +1162,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, if (TF_GetCode(status) != TF_OK) return nullptr; TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", ""); - TFE_OpSetAttrString(op, "shared_name", ""); + TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "shared_name", "", 0); if (TF_GetCode(status) != TF_OK) return nullptr; TFE_TensorHandle* var_handle = nullptr; int num_retvals = 1; diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d6a4f141b6bb8ccadb77f1fa83b5fb742d78f70f..dfdef88945deca376368edd6f7aa322b1e1cbf94 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -273,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) { return ""; // Prevent missing return warning } +bool IsEmptyList(const AttrValue::ListValue& list) { + return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 && + list.b_size() == 0 && list.type_size() == 0 && + list.shape_size() == 0 && list.tensor_size() == 0; +} + string ToCamelCase(const string& str) { string result; const char joiner = '_'; @@ -297,9 +303,9 @@ string ToCamelCase(const string& str) { // indicate whether to treat the type as const when accepting the C++ type as an // argument to a function. std::pair AttrTypeName(StringPiece attr_type) { - static const std::unordered_map, - StringPieceHasher> - attr_type_map{ + static const auto* attr_type_map = + new std::unordered_map, + StringPieceHasher>{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice", true}}, {"int", {"int64", false}}, @@ -317,14 +323,34 @@ std::pair AttrTypeName(StringPiece attr_type) { {"func", {"NameAttrList", true}}, }; - auto entry = attr_type_map.find(attr_type); - if (entry == attr_type_map.end()) { + auto entry = attr_type_map->find(attr_type); + if (entry == attr_type_map->end()) { LOG(FATAL) << "Unsupported Attr type: " << attr_type; return {"", false}; } return entry->second; } +const char* ListElementTypeName(StringPiece attr_type) { + static const auto* attr_list_type_map = + new std::unordered_map{ + {"list(string)", "string"}, + {"list(int)", "int"}, + {"list(float)", "float"}, + {"list(bool)", "bool"}, + {"list(type)", "DataType"}, + {"list(shape)", "PartialTensorShape"}, + {"list(tensor)", "TensorProto"}, + }; + + auto entry = attr_list_type_map->find(attr_type); + if (entry == attr_list_type_map->end()) { + LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type; + return ""; + } + return entry->second; +} + bool IsCPPKeyword(StringPiece name) { static const std::unordered_set // Keywords obtained from http://en.cppreference.com/w/cpp/keyword @@ -668,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; + string defaults_static_storage; for (int i = 0; i < graph_op_def.attr_size(); ++i) { const auto& attr(graph_op_def.attr(i)); @@ -705,11 +732,32 @@ string OpInfo::GetOpAttrStruct() const { "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); - strings::StrAppend( - &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), - "_ = ", - PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), - ";\n"); + string field_initiliazer; + auto& default_value = api_def_attr.default_value(); + if (default_value.value_case() == AttrValue::kList && + !IsEmptyList(default_value.list())) { + // Non-empty lists need static storage for their defaults. Define a + // function with static local variable that stores the array. + strings::StrAppend(&defaults_static_storage, " static ", + attr_type_name, " Default_", api_def_attr.rename_to(), + "() {\n"); + strings::StrAppend( + &defaults_static_storage, " static const ", + ListElementTypeName(attr.type()), " kStorage[] = ", + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), + ";\n"); + strings::StrAppend(&defaults_static_storage, " return ", + attr_type_name, "(kStorage);\n }\n"); + // Set the field_initializer to call the defined function. + strings::StrAppend(&field_initiliazer, "Default_", + api_def_attr.rename_to(), "()"); + } else { + field_initiliazer = + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()); + } + strings::StrAppend(&struct_fields, " ", attr_type_name, " ", + api_def_attr.rename_to(), "_ = ", field_initiliazer, + ";\n"); } if (struct_fields.empty()) { @@ -721,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const { string struct_decl = MakeComment(attrs_comment, " "); strings::StrAppend(&struct_decl, " struct Attrs {\n"); strings::StrAppend(&struct_decl, setters, struct_fields); + if (!defaults_static_storage.empty()) { + strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage); + } strings::StrAppend(&struct_decl, " };\n"); return struct_decl; diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 0025842aead53973befc794378a26fa8db2ae1cb..28070d60dbbe6dd8f930b8e6509cedcf09f94e11 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -287,7 +287,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); const int64 result_index = compile_result.aot->result_buffer_index(); const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes(); - if (result_index < 0 || result_index > temp_sizes.size()) { + if (result_index < 0 || result_index >= temp_sizes.size()) { return errors::InvalidArgument("result index: ", result_index, " is outside the range of temp sizes: [0,", temp_sizes.size(), ")"); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index ab8cd8f4bcd3b5a102692b47cfedfce6a9d9cc47..a92218b1292a436165df16808c58999534e5ba3b 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -181,6 +181,7 @@ cc_library( "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", + "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", ], ) @@ -316,11 +317,11 @@ cc_library( ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", - "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:validate_control_flow", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", @@ -342,6 +343,7 @@ cc_library( "//tensorflow/compiler/jit/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", ], ) diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 731b8ebfdc6262500940274c94a03ae7c0376096..a2e6285339f9ed0bde8d72f5b4752b1ecc22f426 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -66,8 +66,28 @@ class SinglePassSearch { Status CompilationRequested(const FunctionLibraryRuntime& flr, const NodeDef& node_def) { + const FunctionDef* function_def = + flr.GetFunctionLibraryDefinition()->Find(node_def.name()); + if (function_def == nullptr) { + // The node def is not calling a function. Individual ops can be + // run directly using on-demand mode, no need to create XlaLaunch + // kernel for them. + // TODO(b/110359382): Make custom kernel creation return a bool instead of + // status. + // We don't set error messages here to avoid unnecessary string copy. + // Similarly below. + return Status(error::INVALID_ARGUMENT, ""); + } + + // If kXlaCompileAttr is set on the node_def, use its value. + const auto& it = node_def.attr().find(kXlaCompileAttr); + if (it != node_def.attr().end()) { + return it->second.b() ? Status::OK() : Status(error::INVALID_ARGUMENT, ""); + } + + // kXlaCompileAttr is not set on node_def, check if it is set on + // FunctionDef. bool xla_compile = false; - // Check if op is marked _XlaCompile=true. Status status = flr.GetFunctionLibraryDefinition()->GetAttr( node_def, kXlaCompileAttr, &xla_compile); if (!status.ok() || !xla_compile) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 6d1e3325ebd35b9608ea273fb7de39bad381e60d..b78c30c21578b01d1c70b44046326f300d88bb02 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -23,11 +23,11 @@ limitations under the License. #include #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" -#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/validate_control_flow.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" @@ -107,41 +107,11 @@ void MarkGuaranteedConstants( } } -// A node/slot pair. -// TODO(phawkins): is there a common definition of this? -struct NodeSlot { - NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {} - NodeSlot(const Node* node, int slot) - : node(node), slot(slot), dtype(DT_INVALID) {} - NodeSlot(const Node* node, int slot, DataType dtype) - : node(node), slot(slot), dtype(dtype) {} - - const Node* node; - int slot; - - // Optional: used to record the destination type of a source NodeSlot in case - // the source output is a Ref type that is cast to a Tensor at the - // destination. - DataType dtype; - - bool operator==(const NodeSlot& other) const { - return node == other.node && slot == other.slot && dtype == other.dtype; - } - - // Leave dtype out of the hash since there are never two NodeSlots with the - // same node and slot and different dtypes. - struct Hasher { - uint64 operator()(NodeSlot const& s) const { - return Hash64Combine(std::hash()(s.node), - std::hash()(s.slot)); - } - }; - - struct PairHasher { - uint64 operator()(std::pair const& s) const { - return Hash64Combine(Hasher()(s.first), Hasher()(s.second)); - } - }; +struct OutputInputTensorPairHasher { + uint64 operator()(std::pair const& s) const { + return Hash64Combine(OutputTensor::Hash()(s.first), + InputTensor::Hash()(s.second)); + } }; // TODO(phawkins) add a canonical copy of these operator names and refactor @@ -182,8 +152,7 @@ class Encapsulator { // Write a copy of the input graph to 'graph_out', where the subgraphs are // replaced with calls to the new functions. - Status BuildOutputGraph(bool parallel_checking, Graph* graph_out, - FunctionLibraryDefinition* library); + Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library); private: // A subgraph of the input, all marked with a common 'group_attribute' @@ -271,7 +240,7 @@ class Encapsulator { // Adds the function call node to graph_out. Status AddFunctionCallNode( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. Status AddOutsideCompilationHostIONodes( @@ -284,11 +253,9 @@ class Encapsulator { // Subgraph. void GetOutsideCompilationSubgraphNames(std::vector* names) const; - // Returns the Node that inputs to the function should be wired up to. - Node* GetCallNodeForInputs() const; - - // Returns the Node that outputs to the function should be wired up to. - Node* GetCallNodeForOutputs() const; + // Returns the Node that the inputs and outputs of the function should be + // wired up to. + Node* GetCallNode() const; // Returns the index of the arg that the dst of edge should connect to. int GetArgIndexForEdge(const Edge* edge) const; @@ -380,7 +347,7 @@ class Encapsulator { // Map from source (producer node/slot) tensors in the original graph to // input index (slot number in the HostCompute/RecvAtHost nodes that will // be created) for the outside_compilation subgraph. - std::unordered_map inputs; + std::unordered_map inputs; // Set of nodes in the original graph that are the source of control edges // that cross from the containing compiled subgraph into the @@ -396,8 +363,15 @@ class Encapsulator { // node/slot) tensors in the original graph to output index (slot number // in the SendFromHost/HostCompute nodes that will be created) for the // outside_compilation subgraph. - std::unordered_map outputs_by_src; - std::unordered_map outputs_by_dst; + struct ArgNumAndType { + int index; + DataType dtype; + + ArgNumAndType(int i, DataType t) : index(i), dtype(t) {} + }; + std::unordered_map + outputs_by_src; + std::unordered_map outputs_by_dst; // Set of nodes in the original graph that are the destination of control // edges that cross from the outside_compilation subgraph into the @@ -425,12 +399,6 @@ class Encapsulator { OutsideCompilationSubgraph* LookupOrCreateOutsideCompilationSubgraph( const string& outside_compilation_id); - // Builds a ParallelCheck op that compares the output of the original - // subgraph with the encapsulated subgraph. - Status BuildParallelCheckOp( - const std::unordered_map& node_images, - Graph* graph_out); - // Builds a placeholder node used to provide the key input to a RecvAtHost // or SendFromHost node. This placeholder node will be removed by a later // pass. @@ -482,26 +450,21 @@ class Encapsulator { // Not owned. Node* host_compute_key_placeholder_ = nullptr; - // Function call node(s) in the output graph. Not owned. - // If parallel_checking is enabled, 'call_node_inputs' is the function call - // node to which inputs should be fed, and 'call_node_outputs' is the - // parallel check op from which outputs should be read. If parallel checking - // is disabled, both point to the function call node. - Node* call_node_inputs_; - Node* call_node_outputs_; + // Function call node in the output graph. Not owned. + Node* call_node_; // Maps from source (producer node/slot) and destination // (consumer node/slot) tensors in the input graph to _Arg numbers in // the subgraph. The source map is one-to-one, whereas the dest map may be // many-to-one. - std::unordered_map args_by_src_; - std::unordered_map args_by_dst_; + std::unordered_map args_by_src_; + std::unordered_map args_by_dst_; - // The _Arg nodes in the subgraph, in order by argument number. + // The arguments to the subgraph, in order. std::vector args_; // Map from source tensor in the input graph to result #. - std::unordered_map results_; + std::unordered_map results_; // The outside_compilation clusters in this subgraph. std::unordered_map @@ -541,13 +504,12 @@ class Encapsulator { // Copies all nodes that aren't in a compiled subgraph to the output graph. Status CopyNodesToOutputGraph( - bool parallel_checking, Graph* graph_out, - std::unordered_map* node_images); + Graph* graph_out, std::unordered_map* node_images); // Adds function call nodes for each compiled subgraph. Status AddFunctionCallNodes( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all // outside_compilation subgraphs. @@ -598,9 +560,9 @@ class Encapsulator { const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out, - std::unordered_set, NodeSlot::PairHasher>* - edges_added); + Graph* graph_out, + std::unordered_set, + OutputInputTensorPairHasher>* edges_added); // Adds control dependencies between subgraph call nodes that have // dependencies via outside_compilation edges. @@ -609,7 +571,7 @@ class Encapsulator { // Adds all edges to the output graph. Status AddEdgesToOutputGraph( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Constructs a minimal shape inference graph that can be used to determine // the shape of send_node at the time that the subgraph is compiled. @@ -729,20 +691,14 @@ void TopologicalClusterSort( } // namespace -Node* Encapsulator::Subgraph::GetCallNodeForInputs() const { - return call_node_inputs_; -} - -Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const { - return call_node_outputs_; -} +Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; } int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const { - return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input())); + return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input())); } int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const { - return results_.at(NodeSlot(edge->src(), edge->src_output())); + return results_.at(OutputTensor(edge->src(), edge->src_output())); } Node* Encapsulator::Subgraph::GetRecvAtHostNode( @@ -754,7 +710,7 @@ Node* Encapsulator::Subgraph::GetRecvAtHostNode( int Encapsulator::Subgraph::GetRecvAtHostSlot( const string& outside_compilation_subgraph_name, const Edge* edge) const { return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .inputs.at(NodeSlot(edge->src(), edge->src_output())); + .inputs.at(OutputTensor(edge->src(), edge->src_output())); } Node* Encapsulator::Subgraph::GetSendFromHostNode( @@ -766,7 +722,7 @@ Node* Encapsulator::Subgraph::GetSendFromHostNode( int Encapsulator::Subgraph::GetSendFromHostSlot( const string& outside_compilation_subgraph_name, const Edge* edge) const { return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) - .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); + .outputs_by_dst.at(InputTensor(edge->dst(), edge->dst_input())); } Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { @@ -791,10 +747,10 @@ Status Encapsulator::Subgraph::RecordArg( std::vector>* src_arg_pairs) { Node* src_node = edge->src(); int src_slot = edge->src_output(); - std::unordered_map::iterator iter; + std::unordered_map::iterator iter; bool inserted; - std::tie(iter, inserted) = - args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size()); + std::tie(iter, inserted) = args_by_src_.emplace( + OutputTensor(src_node, src_slot), args_by_src_.size()); int arg_index = iter->second; if (inserted) { NodeDef arg_def; @@ -815,7 +771,7 @@ Status Encapsulator::Subgraph::RecordArg( Node* dst_node = edge->dst(); Node* dst_image = node_images.at(dst_node); int dst_slot = edge->dst_input(); - args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index; + args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index; graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot); return Status::OK(); } @@ -826,10 +782,10 @@ Status Encapsulator::Subgraph::RecordResult( Node* src_node = edge->src(); Node* src_image = node_images.at(src_node); int src_slot = edge->src_output(); - std::unordered_map::iterator iter; + std::unordered_map::iterator iter; bool inserted; std::tie(iter, inserted) = - results_.emplace(NodeSlot(src_node, src_slot), results_.size()); + results_.emplace(OutputTensor(src_node, src_slot), results_.size()); int ret_index = iter->second; if (inserted) { NodeDef ret_def; @@ -867,8 +823,8 @@ void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( outside_subgraph->control_inputs.insert(edge->src()); } else { int input_index = outside_subgraph->inputs.size(); - outside_subgraph->inputs.emplace(NodeSlot(edge->src(), edge->src_output()), - input_index); + outside_subgraph->inputs.emplace( + OutputTensor(edge->src(), edge->src_output()), input_index); } } @@ -882,11 +838,13 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( DataType dtype = edge->dst()->input_type(edge->dst_input()); auto output_iter = outside_subgraph->outputs_by_src - .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), - outside_subgraph->outputs_by_src.size()) + .emplace(OutputTensor(edge->src(), edge->src_output()), + OutsideCompilationSubgraph::ArgNumAndType( + outside_subgraph->outputs_by_src.size(), dtype)) .first; - int output_index = output_iter->second; - outside_subgraph->outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + const int output_index = output_iter->second.index; + outside_subgraph + ->outputs_by_dst[InputTensor(edge->dst(), edge->dst_input())] = output_index; } } @@ -968,7 +926,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (const auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); - int src_slot = input_src.first.slot; + int src_slot = input_src.first.index; int input_index = input_src.second; DataType dtype = src_node->output_type(src_slot); @@ -976,8 +934,8 @@ Status Encapsulator::Subgraph::AddHostComputes( input_dtypes[input_index] = dtype; } for (const auto& output : oc_subgraph.outputs_by_src) { - DataType dtype = output.first.dtype; - int output_index = output.second; + DataType dtype = output.second.dtype; + int output_index = output.second.index; output_dtypes[output_index] = dtype; } @@ -1015,7 +973,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (auto& input_src : oc_subgraph.inputs) { const Node* src_node = input_src.first.node; Node* src_image = node_images.at(src_node); - int src_slot = input_src.first.slot; + int src_slot = input_src.first.index; int input_index = input_src.second; graph_->AddEdge(src_image, src_slot, host_compute, input_index); } @@ -1037,7 +995,7 @@ Status Encapsulator::Subgraph::AddHostComputes( for (const auto& output : oc_subgraph.outputs_by_dst) { const Node* dst_node = output.first.node; Node* dst_image = node_images.at(dst_node); - int dst_slot = output.first.slot; + int dst_slot = output.first.index; int output_index = output.second; graph_->AddEdge(host_compute, output_index, dst_image, dst_slot); @@ -1075,7 +1033,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { VLOG(2) << "ConnectSequencerToCallNode"; - graph_out->AddControlEdge(sequencer_, call_node_inputs_); + graph_out->AddControlEdge(sequencer_, call_node_); } } @@ -1090,14 +1048,19 @@ Status Encapsulator::Subgraph::BuildFunctionDef( call_node_def_.set_device(device_); if (rewrite_subgraph_fn) { + std::vector arg_source_tensors(args_by_src_.size()); + for (const auto& arg : args_by_src_) { + arg_source_tensors.at(arg.second) = arg.first; + } // Initialize the input and output permutations to the identity. std::vector input_permutation(args_by_src_.size()); std::iota(input_permutation.begin(), input_permutation.end(), 0); std::vector output_permutation(results_.size()); std::iota(output_permutation.begin(), output_permutation.end(), 0); - TF_RETURN_IF_ERROR(rewrite_subgraph_fn( - &graph_, &input_permutation, &output_permutation, &call_node_def_)); + TF_RETURN_IF_ERROR( + rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation, + &output_permutation, &call_node_def_)); // Apply the input/output permutations to the 'args_by_...' and 'results_' // mappings, so when we build edges in BuildOutputGraph() we @@ -1200,83 +1163,16 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( return Status::OK(); } -Status Encapsulator::Subgraph::BuildParallelCheckOp( - const std::unordered_map& node_images, - Graph* graph_out) { - // Build an index mapping output positions to node/slot pairs in the - // original graph. - std::vector results_by_num(results_.size()); - for (const auto& entry : results_) { - results_by_num[entry.second] = entry.first; - } - - // Build a parallel check NodeDef. - int num_results = results_by_num.size(); - std::vector result_dtypes(num_results); - std::vector expected_outputs(num_results); - std::vector actual_outputs(num_results); - for (int i = 0; i < num_results; ++i) { - const NodeSlot& node_slot = results_by_num[i]; - result_dtypes[i] = node_slot.node->output_type(node_slot.slot); - expected_outputs[i] = - NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(), - node_slot.slot, result_dtypes[i]); - actual_outputs[i] = - NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]); - } - // Assign the parallel check op to a CPU on the same task as the cluster it is - // checking. - string device, dummy; - if (!DeviceNameUtils::SplitDeviceName( - call_node_inputs_->assigned_device_name(), &device, &dummy)) { - return errors::InvalidArgument("Could not parse device name"); - } - strings::StrAppend(&device, "/cpu:0"); - - NodeDef check_def; - TF_RETURN_IF_ERROR( - NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(), - "_parallel_check")), - "ParallelCheck") - .Device(device) - .Attr("T", result_dtypes) - .Input(expected_outputs) - .Input(actual_outputs) - .Finalize(&check_def)); - - Status s; - Node* check_op = graph_out->AddNode(check_def, &s); - if (!s.ok()) return s; - check_op->set_assigned_device_name(device); - - // TODO(phawkins): it seems redundant to call AddEdge as well as - // pass Inputs to the NodeDefBuilder, but I have been unable to find a - // way to avoid it. - for (int i = 0; i < num_results; ++i) { - const NodeSlot& node_slot = results_by_num[i]; - graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op, - i); - graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i); - } - - call_node_outputs_ = check_op; - return Status::OK(); -} - Status Encapsulator::Subgraph::AddFunctionCallNode( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { Status s; - call_node_inputs_ = graph_out->AddNode(call_node_def_, &s); + call_node_ = graph_out->AddNode(call_node_def_, &s); if (!s.ok()) return s; // Copy the assigned device and the key_annotation over. - call_node_inputs_->set_assigned_device_name(device_); - call_node_outputs_ = call_node_inputs_; + call_node_->set_assigned_device_name(device_); - if (parallel_checking) { - TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out)); - } return Status::OK(); } @@ -1315,7 +1211,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( for (const auto& input : oc_subgraph->inputs) { const Node* src_node = input.first.node; - int src_slot = input.first.slot; + int src_slot = input.first.index; int input_index = input.second; DataType dtype = src_node->output_type(src_slot); @@ -1369,8 +1265,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( for (const auto& output : oc_subgraph->outputs_by_src) { const Node* src_node = output.first.node; Node* src_image = node_images.at(src_node); - int src_slot = output.first.slot; - int output_index = output.second; + int src_slot = output.first.index; + int output_index = output.second.index; DataType dtype = src_node->output_type(src_slot); dtypes[output_index] = dtype; @@ -1609,6 +1505,11 @@ Status Encapsulator::SplitIntoSubgraphs() { for (auto& entry : subgraphs_) { Subgraph& subgraph = entry.second; FixupSourceAndSinkEdges(subgraph.GetGraph()); + // Verify that the graph has well-formed control flow structure to be + // functionalized. + std::vector dummy; + TF_RETURN_IF_ERROR( + BuildAndValidateControlFlowInfo(subgraph.GetGraph(), &dummy)); } return s; @@ -1627,27 +1528,17 @@ Status Encapsulator::BuildFunctionDefs( } Status Encapsulator::CopyNodesToOutputGraph( - bool parallel_checking, Graph* graph_out, - std::unordered_map* node_images) { + Graph* graph_out, std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; string outside_compilation_id; TF_RETURN_IF_ERROR( GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); - // Don't copy nodes that going to be encapsulated, unless parallel checking - // is enabled. - if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking) - continue; + // Don't copy nodes that are going to be encapsulated. + if (IsInSubgraph(func_id, outside_compilation_id)) continue; Node* image = graph_out->CopyNode(node); - if (!outside_compilation_id.empty()) { - if (parallel_checking) { - return errors::InvalidArgument( - "Parallel checking is not supported when outside_compilation " - "clusters are present."); - } - } (*node_images)[node] = image; } (*node_images)[graph_in_->source_node()] = graph_out->source_node(); @@ -1657,10 +1548,10 @@ Status Encapsulator::CopyNodesToOutputGraph( Status Encapsulator::AddFunctionCallNodes( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { for (auto& subgraph_entry : subgraphs_) { - TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode( - node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR( + subgraph_entry.second.AddFunctionCallNode(node_images, graph_out)); } return Status::OK(); } @@ -1694,7 +1585,7 @@ Status Encapsulator::FindOutputImageOfEdgeSrc( } else { // The edge is from a subgraph to a regular node in the output graph so // use the subgraph's call node output. - *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); + *src_image = subgraphs_.at(src_func_id).GetCallNode(); } } else { // The source of the edge is in the output graph so use the node image in @@ -1742,7 +1633,7 @@ Status Encapsulator::FindOutputImageOfEdgeDst( } else { // The edge is to a subgraph from a regular node in the output graph so // use the subgraph's call node input. - *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); + *dst_image = subgraphs_.at(dst_func_id).GetCallNode(); } } else { // The destination of the edge is in the output graph so use the node image @@ -1778,10 +1669,9 @@ Status Encapsulator::CopyEdgeToOutputGraph( const Edge* edge, const string& src_func_id, const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, - const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out, - std::unordered_set, NodeSlot::PairHasher>* - edges_added) { + const std::unordered_map& node_images, Graph* graph_out, + std::unordered_set, + OutputInputTensorPairHasher>* edges_added) { Node* src_image; TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( src_func_id, src_outside_compilation_id, dst_func_id, @@ -1796,16 +1686,12 @@ Status Encapsulator::CopyEdgeToOutputGraph( if (edge->IsControlEdge()) { // Add the control edge, if we have not already added it, using the images // determined above (potentially call operators or RecvAtHost/SendFromHost). - if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) + if (edges_added + ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1)) .second) { graph_out->AddControlEdge(src_image, dst_image); } - // If parallel checking is enabled, also add a control edge to the - // corresponding parallel check op. - if (parallel_checking) { - graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); - } return Status::OK(); } @@ -1817,18 +1703,10 @@ Status Encapsulator::CopyEdgeToOutputGraph( FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id, dst_func_id, dst_outside_compilation_id, edge); - if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) && - parallel_checking) { - // If we are parallel checking, also feed the tensor as an input to the - // corresponding parallel check subgraph. - graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), - edge->dst_input()); - } - // Add the edge, if we have not already added it. if (edges_added - ->emplace(NodeSlot(src_image, src_output), - NodeSlot(dst_image, dst_input)) + ->emplace(OutputTensor(src_image, src_output), + InputTensor(dst_image, dst_input)) .second) { graph_out->AddEdge(src_image, src_output, dst_image, dst_input); } @@ -1839,8 +1717,8 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { for (const auto& ancestors : subgraph_ancestors_) { const string& subgraph = ancestors.first; for (const string& ancestor : ancestors.second) { - graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNodeForOutputs(), - subgraphs_[subgraph].GetCallNodeForInputs()); + graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(), + subgraphs_[subgraph].GetCallNode()); } } return Status::OK(); @@ -1848,11 +1726,12 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { Status Encapsulator::AddEdgesToOutputGraph( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { // Set of edges already added to the output graph, represented as (src, dst) // pairs. We use the set to deduplicate edges; multiple edges in the input // graph may map to one edge in the output graph. - std::unordered_set, NodeSlot::PairHasher> + std::unordered_set, + OutputInputTensorPairHasher> edges_added; for (const Edge* edge : graph_in_->edges()) { @@ -1870,16 +1749,6 @@ Status Encapsulator::AddEdgesToOutputGraph( if (IsInSubgraph(src_func_id, src_outside_compilation_id) && IsInSubgraph(dst_func_id, dst_outside_compilation_id) && src_func_id == dst_func_id) { - if (parallel_checking) { - Node* src_image = node_images.at(edge->src()); - Node* dst_image = node_images.at(edge->dst()); - if (edge->IsControlEdge()) { - graph_out->AddControlEdge(src_image, dst_image); - } else { - graph_out->AddEdge(src_image, edge->src_output(), dst_image, - edge->dst_input()); - } - } continue; } @@ -1887,8 +1756,7 @@ Status Encapsulator::AddEdgesToOutputGraph( // unclustered graph. TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph( edge, src_func_id, src_outside_compilation_id, dst_func_id, - dst_outside_compilation_id, node_images, parallel_checking, graph_out, - &edges_added)); + dst_outside_compilation_id, node_images, graph_out, &edges_added)); } for (auto& subgraph_entry : subgraphs_) { @@ -2504,18 +2372,15 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, +Status Encapsulator::BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. std::unordered_map node_images; - TF_RETURN_IF_ERROR( - CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images)); - TF_RETURN_IF_ERROR( - AddFunctionCallNodes(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images)); + TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out)); TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out)); - TF_RETURN_IF_ERROR( - AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out)); TF_RETURN_IF_ERROR( GetShapeInfoForOutsideCompilationSends(graph_out, library)); @@ -2528,8 +2393,8 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, Status EncapsulateSubgraphsInFunctions( string group_attribute, string outside_compilation_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, - bool parallel_checking, bool reuse_existing_functions, - std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library) { Status s; Encapsulator encapsulator(std::move(group_attribute), @@ -2543,8 +2408,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); - TF_RETURN_IF_ERROR( - encapsulator.BuildOutputGraph(parallel_checking, out.get(), library)); + TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library)); *graph_out = std::move(out); return Status::OK(); @@ -2585,8 +2449,6 @@ static Status RenumberArguments(Graph* graph, Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; - legacy_flags::EncapsulateSubgraphsPassFlags* flags = - legacy_flags::GetEncapsulateSubgraphsPassFlags(); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph, options.flib_def); @@ -2602,69 +2464,73 @@ Status EncapsulateSubgraphsPass::Run( FunctionLibraryRuntime* flr = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); - auto rewrite_subgraph = [flr](std::unique_ptr* subgraph, - std::vector* input_permutation, - std::vector* output_permutation, - NodeDef* node) { - // Optimize the subgraph. - OptimizeGraph(flr, subgraph); - - const int num_args = input_permutation->size(); - std::vector const_args(num_args); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); - - DataTypeVector arg_types(num_args); - TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); - - // Compute a permutation of the arguments such that the constant arguments - // are first. - const int num_consts = - std::count(const_args.begin(), const_args.end(), true); - - const int num_resources = - std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE); - const int num_nonconsts = num_args - num_resources - num_consts; - if (num_nonconsts < 0) { - return errors::Internal("num_nonconsts should be >= 0, was ", - num_nonconsts); - } - - int const_pos = 0; - int arg_pos = num_consts; - int resource_pos = num_consts + num_nonconsts; - for (int i = 0; i < num_args; ++i) { - if (const_args[i]) { - if (arg_types[i] == DT_RESOURCE) { - return errors::Internal( - "Resource arguments cannot be constant (argument ", i, ")"); + auto rewrite_subgraph = + [flr](const std::vector& arg_source_tensors, + std::unique_ptr* subgraph, + std::vector* input_permutation, + std::vector* output_permutation, NodeDef* node) { + // Optimize the subgraph. + OptimizeGraph(flr, subgraph); + + const int num_args = input_permutation->size(); + std::vector const_args(num_args); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); + + DataTypeVector arg_types(num_args); + TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); + + // Compute a permutation of the arguments such that the constant + // arguments are first. + const int num_consts = + std::count(const_args.begin(), const_args.end(), true); + + const int num_resources = + std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE); + const int num_nonconsts = num_args - num_resources - num_consts; + if (num_nonconsts < 0) { + return errors::Internal("num_nonconsts should be >= 0, was ", + num_nonconsts); } - (*input_permutation)[i] = const_pos; - ++const_pos; - } else if (arg_types[i] == DT_RESOURCE) { - (*input_permutation)[i] = resource_pos; - ++resource_pos; - } else { - (*input_permutation)[i] = arg_pos; - ++arg_pos; - } - } - // Renumber argument nodes in the graph. - TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation)); - - // TODO(phawkins): add a forward is-constant analysis, similarly split - // outputs into host-memory constants and device-memory non-constants. - - AddNodeAttr(kXlaCompiledKernelAttr, true, node); - AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); - AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); - return Status::OK(); - }; + int const_pos = 0; + int arg_pos = num_consts; + int resource_pos = num_consts + num_nonconsts; + for (int i = 0; i < num_args; ++i) { + if (const_args[i]) { + if (arg_types[i] == DT_RESOURCE) { + return errors::Internal( + "Resource arguments cannot be constant (argument ", i, ")"); + } + (*input_permutation)[i] = const_pos; + ++const_pos; + } else if (arg_types[i] == DT_RESOURCE) { + (*input_permutation)[i] = resource_pos; + ++resource_pos; + } else { + (*input_permutation)[i] = arg_pos; + ++arg_pos; + } + } - TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, - rewrite_subgraph, flags->tf_xla_parallel_checking, - /*reuse_existing_functions=*/false, &graph_out, library)); + // Renumber argument nodes in the graph. + TF_RETURN_IF_ERROR( + RenumberArguments(subgraph->get(), *input_permutation)); + + // TODO(phawkins): add a forward is-constant analysis, similarly split + // outputs into host-memory constants and device-memory non-constants. + + AddNodeAttr(kXlaCompiledKernelAttr, true, node); + AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node); + AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node); + return Status::OK(); + }; + + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EncapsulateSubgraphsInFunctions( + kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, + rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out, + library), + "EncapsulateSubgraphsPass failed"); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 5fee36f022a7515504cb6faa5cca658481b784c5..926589546fec72048485d30966f31b24e44b1245 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -28,6 +28,9 @@ limitations under the License. namespace tensorflow { // A rewriting function to apply to each subgraph during encapsulation. +// 'arg_source_tensors' are the tensors corresponding to the arguments in the +// original source graph (*not* 'graph'). +// // 'graph' is the subgraph. The rewriting may renumber the inputs and outputs; // 'input_permutation' is a mapping from old argument numbers to new argument // numbers, whereas 'output_permutation' is the same for outputs. Both @@ -37,6 +40,7 @@ namespace tensorflow { // The rewrite may also change the NodeDef's operator name, and that // name will be used as the name of the generated function. typedef std::function& arg_source_tensors, std::unique_ptr* graph, std::vector* input_permutation, std::vector* output_permutation, NodeDef* node_def)> RewriteSubgraphFn; @@ -61,10 +65,6 @@ typedef std::function* graph_out, FunctionLibraryDefinition* library); + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate // subgraphs pass and that should in turn be compiled via XlaLaunch operators. diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index eef113a3547f0b2f648680d5f51650f70dbbd261..4eb389e0c653f2d32c17f448687f865a44a11b96 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -511,7 +511,6 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_out, lib_def.get()); if (!s.ok()) return s; @@ -560,8 +559,9 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { Node* b = Input(b1.opts().WithName("B")); // Give nodes 'c' and 'd' names that collide after lowercasing. Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); - Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr( - "_encapsulate", "F1")); + Node* d = Binary(b, c, + b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); Binary(a, d, b1.opts().WithName("E")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } @@ -614,8 +614,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { Node* c = Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr( "_encapsulate", "F1")); - Node* d = - Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr( + Node* d = Binary(b, c, + b1.opts().WithName("D").WithControlInput(control).WithAttr( "_encapsulate", "F2")); Binary(a, d, b1.opts().WithName("E")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); @@ -707,7 +707,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", "_outside", graph_before_encapsulation, - /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, + /*rewrite_subgraph_fn=*/{}, /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; @@ -721,47 +721,6 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } -TEST(EncapsulateSubgraphsTest, ParallelChecking) { - Scope root = Scope::NewRootScope().ExitOnError().WithDevice( - "/job:localhost/replica:0/task:0/cpu:0"); - auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); - auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); - auto add1 = ops::Add(root.WithOpName("add1"), x1, x2); - add1.node()->AddAttr("_cluster", "cluster1"); - auto add2 = ops::Add(root.WithOpName("add2"), add1, x2); - add2.node()->AddAttr("_cluster", "cluster1"); - auto out = ops::Mul(root.WithOpName("mul"), x1, add2); - - Graph graph_before_encapsulation(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation)); - - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - std::unique_ptr graph; - TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", "_outside", graph_before_encapsulation, - /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true, - /*reuse_existing_functions=*/false, &graph, &library)); - - std::vector expected_nodes = { - "add1", "add2", "cluster1", "cluster1_parallel_check/_0", - "mul", "x1", "x2"}; - EXPECT_EQ(expected_nodes, GraphNodes(*graph)); - - std::vector> expected_edges = { - {"add1:0", "add2:0"}, - {"add2:0", "cluster1_parallel_check/_0:0"}, - {"cluster1:0", "cluster1_parallel_check/_0:1"}, - {"cluster1_parallel_check/_0:0", "mul:1"}, - {"x1:0", "add1:0"}, - {"x1:0", "cluster1:0"}, - {"x1:0", "mul:0"}, - {"x2:0", "add1:1"}, - {"x2:0", "add2:1"}, - {"x2:0", "cluster1:1"}, - }; - EXPECT_EQ(expected_edges, GraphEdges(*graph)); -} - const Node* FindNodeByName(const Graph& graph, const string& name) { for (const Node* node : graph.nodes()) { if (node->name() == name) return node; @@ -798,7 +757,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_encapsulate", "_outside", graph_before, /*rewrite_subgraph_fn=*/ - [&guaranteed_consts](std::unique_ptr* graph_ptr, + [&guaranteed_consts](const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, std::vector* input_permutation, std::vector* output_permutation, NodeDef* call_def) { @@ -814,7 +774,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { } return Status::OK(); }, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_after, &library)); EXPECT_EQ(2, guaranteed_consts); } @@ -843,7 +802,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_encapsulate", "_outside", graph_before, /*rewrite_subgraph_fn=*/ - [&guaranteed_consts](std::unique_ptr* graph_ptr, + [&guaranteed_consts](const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, std::vector* input_permutation, std::vector* output_permutation, NodeDef* call_def) { @@ -859,7 +819,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { } return Status::OK(); }, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_after, &library)); // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const // and another non-const, so overall non-const. diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 5d211f4d733d8d807426e62dd116092799184f35..5b6692f523658749f7ef48f9d7d89e97d4ce8b09 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -16,18 +16,6 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -cc_library( - name = "encapsulate_subgraphs_pass_flags", - srcs = ["encapsulate_subgraphs_pass_flags.cc"], - hdrs = ["encapsulate_subgraphs_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "mark_for_compilation_pass_flags", srcs = ["mark_for_compilation_pass_flags.cc"], diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc deleted file mode 100644 index 856475f12c8a411cd80c1c1859323304ca4029e0..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static EncapsulateSubgraphsPassFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new EncapsulateSubgraphsPassFlags; - flags->tf_xla_parallel_checking = false; - flag_list = new std::vector({ - Flag("tf_xla_parallel_checking", &flags->tf_xla_parallel_checking, - "Debug tool. Runs both JIT-compiled and interpreted graphs in " - "parallel and verifies they produce the same outputs."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -void AppendEncapsulateSubgraphsPassFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the EncapsulateSubgraphsPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h deleted file mode 100644 index d371bd269dbdfbf737d81490fb877fcf88661a8f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ - -// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -void AppendEncapsulateSubgraphsPassFlags( - std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -typedef struct { - bool tf_xla_parallel_checking; // Debug tool. Runs both JIT-compiled and - // interpreted graphs in parallel and verifies - // they produce the same outputs. -} EncapsulateSubgraphsPassFlags; - -// Return a pointer to the EncapsulateSubgraphsPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 74468266b9e983431732eafc801bc2d2ea682be9..8c3882116dd4f048ea3e32c037bf4139c67a3eb9 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -44,12 +44,6 @@ namespace tensorflow { namespace { -// Returns true if, when executed in TensorFlow, `node` is guaranteed to forward -// a ref tensor input to its output. -static bool AlwaysForwardsRefInput(const Node& node) { - return node.IsIdentity(); -} - bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient // is really a kind of function call and will be handled by @@ -68,20 +62,8 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // XLA does not offer guaranteed aliasing between the input and output of the // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave // such nodes out of XLA clusters. - if (AlwaysForwardsRefInput(node)) { - for (const Edge* incoming_edge : node.in_edges()) { - if (incoming_edge->IsControlEdge()) { - continue; - } - - Node* incoming_node = incoming_edge->src(); - if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) { - VLOG(2) << "Not clustering " << node.def().ShortDebugString() - << " because of ref input " << incoming_node->name() << " " - << incoming_node->type_string(); - return false; - } - } + if (HasForwardedRefInput(node)) { + return false; } return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 70bd10336b824b4aaef6520f0b094f52e5a0d626..05b7821b8865d0f210ca9af92370e177d6043e80 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/util/device_name_utils.h" @@ -66,6 +67,9 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, } return description; } + +bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); } + } // namespace Status DeviceToDeviceType(const string& device, DeviceType* device_type) { @@ -77,6 +81,24 @@ Status DeviceToDeviceType(const string& device, DeviceType* device_type) { return Status::OK(); } +bool HasForwardedRefInput(const Node& node) { + if (AlwaysForwardsRefInput(node)) { + for (const Edge* incoming_edge : node.in_edges()) { + if (incoming_edge->IsControlEdge()) { + continue; + } + + Node* incoming_node = incoming_edge->src(); + if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) { + VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input " + << incoming_node->name() << " " << incoming_node->type_string(); + return true; + } + } + } + return false; +} + Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { for (int i = 0; i < graph->num_node_ids(); ++i) { // We rely on the node IDs in the cycle detection graph being consecutive diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 5b673bdc27fccb4228b9e02cbf80d17aa35b5fe5..bcce082aaf6044ff0654efa4d78c0f493a350d00 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -36,6 +36,9 @@ using OrderedNodeSet = std::set; // Returns the DeviceType corresponding to 'device'. Status DeviceToDeviceType(const string& device, DeviceType* device_type); +// Returns true if `node` has a ref tensor input that it forwards to its output. +bool HasForwardedRefInput(const Node& node); + // Creates a graph representation to enable cycle detection when clustering. // This representation handles loops in graph by disconnecting each loop from // the enclosing graph. diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index b1943d3e1a7e321b5a3796a0c6e4f2b5d9ac7018..9beeb3517e4630268944fa8d9b97b98114e8faf9 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -61,14 +61,24 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; TF_RET_CHECK(stream); - VLOG(2) << "Executing computation."; + VLOG(2) << "Executing computation: " << name(); + for (const xla::ShapedBuffer* arg : launch_context.arguments()) { + VLOG(2) << name() << ": " << *arg; + } xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(client->backend().memory_allocator()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(ctx->step_id()); - auto run_result = executable->Run(launch_context.arguments(), run_options); + xla::StatusOr run_result; + { + // TODO(b/110383871): fix concurrency problems and remove this mutex. + static mutex* mu = new mutex; + mutex_lock lock(*mu); + + run_result = executable->Run(launch_context.arguments(), run_options); + } TF_RETURN_IF_ERROR(run_result.status()); launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 71e63b110b3b132a57fc291e53a165954c72a03c..37005479dc7dd27cabc945f2753e20477a71549a 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -74,7 +74,7 @@ Status XlaTransferManager::TransferLiteralToDevice( XlaTensor::FromTensor(device_tensor)->shaped_buffer(); VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " << shaped_buffer.ToString(); - return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal, + return transfer_manager_->TransferLiteralToDevice(stream_, literal, shaped_buffer); } @@ -83,9 +83,9 @@ Status XlaTransferManager::TransferLiteralFromDevice( const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(&device_tensor)->shaped_buffer(); - TF_ASSIGN_OR_RETURN(std::unique_ptr literal, - transfer_manager_->TransferLiteralFromDevice( - stream_->parent(), shaped_buffer)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager_->TransferLiteralFromDevice(stream_, shaped_buffer)); VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " " << shaped_buffer.ToString(); Tensor tensor; diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 0c49286acd3abaf8ea1f12a90d86a1d1ff38b234..11e45d2823da2b623bd3cd45f7147686b05fdb2f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" +#include "tensorflow/core/kernels/shape_ops.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { @@ -87,6 +88,46 @@ class XlaAssignVariableOp : public AsyncOpKernel { REGISTER_KERNEL_BUILDER( \ Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ ReadVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + RankOp); \ REGISTER_KERNEL_BUILDER( \ Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \ XlaAssignVariableOp); \ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 96016521ea902274e3ec1dcc35d3d070063eb1ae..74257b09a808a39454eace3b1a9bf57a2e071360 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -178,6 +178,13 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, continue; } + // XLA does not offer guaranteed aliasing between the input and output of + // the XLA cluster so it can't implement the forward-tensor-ref semantic. + // Leave such nodes out of XLA clusters. + if (HasForwardedRefInput(*node)) { + continue; + } + compilation_candidates.insert(node); } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index b51c11bf6e9b952d9e282b498101ec4f73f87885..9ec6b6b7496528e3c88573865367e53d205e566f 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -51,6 +51,15 @@ py_library( ], ) +py_test( + name = "xla_test_test", + size = "small", + srcs = ["xla_test_test.py"], + deps = [ + ":xla_test", + ], +) + tf_xla_py_test( name = "adagrad_test", size = "small", @@ -539,13 +548,18 @@ tf_xla_py_test( name = "random_ops_test", size = "small", srcs = ["random_ops_test.py"], - # TODO(b/31361304): enable RNG ops on GPU when parallelized. disabled_backends = [ + # TODO(b/110300529): RngNormal doesn't return values with the expected variance + "cpu", + "cpu_ondemand", + # TODO(b/31361304): enable RNG ops on GPU when parallelized. "gpu", ], deps = [ ":xla_test", + "//tensorflow/python:array_ops", "//tensorflow/python:framework", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -828,6 +842,18 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "sort_ops_test", + size = "small", + srcs = ["sort_ops_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + ], +) + tf_xla_py_test( name = "xla_device_test", size = "small", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 1e4dd32916c3a40282735fb8f75670b0e9ef0dc9..69a99dd1cd8eb9eacbc0b872a691161c1a0dc79f 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -1216,6 +1216,24 @@ class BinaryOpsTest(XLATestCase): np.array([1, 0], dtype=np.int32), expected=np.array([[1, 3], [2, 4]], dtype=dtype)) + def testConjugateTranspose(self): + for dtype in self.complex_types: + self._testBinary( + array_ops.conjugate_transpose, + np.zeros(shape=[1, 0, 4], dtype=dtype), + np.array([1, 2, 0], dtype=np.int32), + expected=np.zeros(shape=[0, 4, 1], dtype=dtype)) + self._testBinary( + array_ops.conjugate_transpose, + np.array([[1 - 1j, 2 + 2j], [3 - 3j, 4 + 4j]], dtype=dtype), + np.array([0, 1], dtype=np.int32), + expected=np.array([[1 + 1j, 2 - 2j], [3 + 3j, 4 - 4j]], dtype=dtype)) + self._testBinary( + array_ops.conjugate_transpose, + np.array([[1 - 1j, 2 + 2j], [3 - 3j, 4 + 4j]], dtype=dtype), + np.array([1, 0], dtype=np.int32), + expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype)) + def testCross(self): for dtype in self.float_types: self._testBinary( diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 4dff5f0f405fb1d936ab2e6bcd82e05e926172c7..e438832a23a670596d12cbc67d71a9f561b82193 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -31,11 +31,13 @@ from tensorflow.python.framework import ops from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import googletest +from tensorflow.python.training import adam class EagerTest(XLATestCase): @@ -47,6 +49,21 @@ class EagerTest(XLATestCase): product = three * five self.assertAllEqual(15, product) + def testGradientTape(self): + with self.test_scope(): + + x = constant_op.constant(1.0) + y = constant_op.constant(10.0) + with backprop.GradientTape(persistent=True) as tape: + tape.watch(x) + tape.watch(y) + a = x + y + x * y + da_dx = tape.gradient(a, x) + da_dy = tape.gradient(a, y) + + self.assertEqual(11.0, da_dx.numpy()) + self.assertEqual(2.0, da_dy.numpy()) + def testExecuteListOutputLen0(self): with self.test_scope(): empty = constant_op.constant([], dtype=dtypes.float32) @@ -160,12 +177,120 @@ class EagerTest(XLATestCase): for _ in range(100): values.append(var.value()) + # The shape, shape_n, size, and rank are tested here because their + # execution kernels (as opposed to compilation only tf2xla kernels) + # are distincts from tf2xla kernels. + + def testShape(self): + def const(value): + return array_ops.shape( + constant_op.constant(value)).numpy() + + def ones(value): + return array_ops.shape( + array_ops.ones(value)).numpy() + + with self.test_scope(): + # Shapes of directly constructed tensors + self.assertAllEqual([], const(3)) + self.assertAllEqual([3], const([1.0, 2.0, 3.0])) + self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]])) + self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]])) + + # Shapes of tensors created by op running on device + # We make this distinction because directly constructed tensors + # are treated differently in a few places that can influence shape: + # - they always have on_host_tensor + # - they and their shapes can be cached + # - they end up on device via a copy, instead of as program output + self.assertAllEqual([], ones([])) + self.assertAllEqual([3], ones([3])) + self.assertAllEqual([2, 2], ones([2, 2])) + self.assertAllEqual([2, 1, 2], ones([2, 1, 2])) + + def testShapeN(self): + with self.test_scope(): + # Shapes of directly constructed tensors + shapes = array_ops.shape_n([ + constant_op.constant(1.0), + constant_op.constant([1.0, 2.0, 3.0]), + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + # Shapes of tensors created by op running on device + shapes = array_ops.shape_n([ + array_ops.ones([]), + array_ops.ones([3]), + array_ops.ones([2, 2])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + def testSize(self): + with self.test_scope(): + self.assertEqual( + 1, array_ops.size(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 4, array_ops.size( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testRank(self): + with self.test_scope(): + self.assertEqual( + 0, array_ops.rank(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 2, array_ops.rank( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testAdam(self): + with self.test_scope(): + optimizer = adam.AdamOptimizer(0.1) + x = resource_variable_ops.ResourceVariable(10.0) + with backprop.GradientTape() as tape: + y = x * x + dy_dx = tape.gradient(y, x) + optimizer.apply_gradients([(dy_dx, x)]) + self.assertAlmostEqual(9.9, x.numpy(), places=3) + + def testAdamSparse(self): + with ops.device('/cpu:0'): + # Create 2-D embedding for 3 objects on CPU because sparse/sliced updates + # are not implemented on TPU. + embedding_matrix = resource_variable_ops.ResourceVariable( + array_ops.ones([3, 2])) + + with self.test_scope(): + with backprop.GradientTape() as tape: + embedding = embedding_ops.embedding_lookup(embedding_matrix, [1]) + y = math_ops.reduce_sum(embedding) + dy_dx = tape.gradient(y, embedding_matrix) + self.assertIsInstance(dy_dx, ops.IndexedSlices) + optimizer = adam.AdamOptimizer(0.1) + # The gradient application operations will run on CPU because optimizer + # updates are always collocated with the variable. + optimizer.apply_gradients([(dy_dx, embedding_matrix)]) + + # This assign_add will run on CPU because when an input to an + # operation is a resource, this operation is placed on the resource's + # device by the eager runtime. + embedding_matrix.assign_add(array_ops.ones([3, 2])) + + self.assertAllClose([[2.0, 2.0], + [1.9, 1.9], + [2.0, 2.0]], embedding_matrix.numpy()) + class EagerFunctionTest(XLATestCase): def testBasic(self): with self.test_scope(): - matmul = function.defun(math_ops.matmul, compiled=True) + matmul = function.defun(math_ops.matmul) t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) sq = matmul(t, t, transpose_a=True) self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) @@ -187,7 +312,7 @@ class EagerFunctionTest(XLATestCase): def model(x): x = conv(x) return pool(x) - model = function.defun(model, compiled=True) + model = function.defun(model) x = array_ops.ones([1, 4, 4, 1]) y = model(x) @@ -197,7 +322,7 @@ class EagerFunctionTest(XLATestCase): with self.test_scope(): v = resource_variable_ops.ResourceVariable(1.0) - @function.defun(compiled=True) + @function.defun def f(): return v.read_value() @@ -212,7 +337,7 @@ class EagerFunctionTest(XLATestCase): v.assign_add(1.0) return v - f = function.defun(f, compiled=True) + f = function.defun(f) var = f(v) self.assertEqual(2.0, var.numpy()) @@ -240,7 +365,7 @@ class EagerFunctionTest(XLATestCase): d = r2 * v2 return a, b, c, d - foo = function.defun(foo, compiled=True) + foo = function.defun(foo) c1 = [0, 0] c2 = array_ops.ones([2], dtype=dtypes.int32) @@ -262,7 +387,7 @@ class EagerFunctionTest(XLATestCase): with self.test_scope(): v0 = resource_variable_ops.ResourceVariable(5.0) - @function.defun(compiled=True) + @function.defun def f(x): x = v0 * v0 * x return x @@ -275,6 +400,24 @@ class EagerFunctionTest(XLATestCase): self.assertEqual(75, y.numpy()) self.assertEqual(30, dy.numpy()) + def testSliceInDefun(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x, y): + return x[0::2, y:, ...] + + x = array_ops.ones([2, 3, 4]) + y = array_ops.ones([], dtype=dtypes.int32) + with backprop.GradientTape() as tape: + tape.watch(x) + tape.watch(y) + z = f(x, y) + dz = tape.gradient(z, x) + + self.assertAllEqual(np.ones([1, 2, 4]), z.numpy()) + self.assertAllEqual((2, 3, 4), dz.shape.as_list()) + class ExcessivePaddingTest(XLATestCase): """Test that eager execution works with TPU flattened tensors. @@ -307,7 +450,7 @@ class ExcessivePaddingTest(XLATestCase): def testAsFunctionInput(self): with self.test_scope(): - @function.defun(compiled=True) + @function.defun def f(x): return math_ops.reduce_sum(x, axis=2) @@ -318,7 +461,7 @@ class ExcessivePaddingTest(XLATestCase): def testAsFunctionOutput(self): with self.test_scope(): - @function.defun(compiled=True) + @function.defun def f(x): return x * constant_op.constant(100 * [[[10.0, 2.0]]]) diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 70be22936a500f53e769607d7ed2c957831af86d..2e71b00ba66dba93c87e565e3a372111de1f362d 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -18,11 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math + import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import special_math from tensorflow.python.platform import googletest @@ -47,18 +52,18 @@ class RandomOpsTest(XLATestCase): # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. self.assertTrue((not np.array_equal(y, z)) or - (not np.array_equal(z, w)) or - (not np.array_equal(y, w))) + (not np.array_equal(z, w)) or (not np.array_equal(y, w))) def testRandomUniformIsNotConstant(self): + def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, - maxval=1000000) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) def testRandomNormalIsNotConstant(self): + def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) @@ -70,13 +75,14 @@ class RandomOpsTest(XLATestCase): for dtype in self._random_types(): with self.test_session() as sess: with self.test_scope(): - x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2, - maxval=33) + x = random_ops.random_uniform( + shape=[1000], dtype=dtype, minval=-2, maxval=33) y = sess.run(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) def testTruncatedNormalIsNotConstant(self): + def rng(dtype): return random_ops.truncated_normal(shape=[2], dtype=dtype) @@ -84,15 +90,75 @@ class RandomOpsTest(XLATestCase): self._testRngIsNotConstant(rng, dtypes.float32) def testTruncatedNormalIsInRange(self): - count = 10000 + count = 10000000 # TODO(b/34339814): implement inverse erf support for non-F32 types. for dtype in [dtypes.float32]: with self.test_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42) y = sess.run(x) - self.assertTrue((y >= -2).sum() == count) - self.assertTrue((y <= 2).sum() == count) + + def normal_cdf(x): + return .5 * math.erfc(-x / math.sqrt(2)) + + def normal_pdf(x): + return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi) + + def probit(x, sess=sess): + return sess.run(special_math.ndtri(x)) + + a = -2. + b = 2. + mu = 0. + sigma = 1. + + alpha = (a - mu) / sigma + beta = (b - mu) / sigma + z = normal_cdf(beta) - normal_cdf(alpha) + + self.assertTrue((y >= a).sum() == count) + self.assertTrue((y <= b).sum() == count) + + # For more information on these calculations, see: + # Burkardt, John. "The Truncated Normal Distribution". + # Department of Scientific Computing website. Florida State University. + expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma + actual_mean = np.mean(y) + self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + + expected_median = mu + probit( + (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma + actual_median = np.median(y) + self.assertAllClose(actual_median, expected_median, atol=8e-4) + + expected_variance = sigma**2 * (1 + ( + (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( + (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) + actual_variance = np.var(y) + self.assertAllClose(actual_variance, expected_variance, rtol=3e-4) + + def testShuffle1d(self): + with self.test_session() as sess: + with self.test_scope(): + x = math_ops.range(20) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = range(20) + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(set(result), set(expected)) + + def testShuffle2d(self): + with self.test_session() as sess: + with self.test_scope(): + x = array_ops.diag(math_ops.range(20)) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = np.diag(range(20)).flatten() + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(len(result.flatten()), len(expected)) + self.assertAllEqual(set(result.flatten()), set(expected)) if __name__ == '__main__': diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8ae579abda9854079ee491a7254eb4d09183594a --- /dev/null +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -0,0 +1,131 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sorting operators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class XlaSortOpTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + if isinstance(output, ops.Tensor): + output = [output] + + results = session.run(output, feeds) + for result, v in zip(results, expected): + self.assertAllClose(v, result, rtol=1e-3) + + def testSort(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + for dtype in supported_types.intersection(self.numeric_types): + x = np.arange(101, dtype=dtype) + np.random.shuffle(x) + self._assertOpOutputMatchesExpected( + xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) + + def testTopK(self): + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 in self.numeric_types: + for x in [np.arange(20)]: + np.random.shuffle(x) + for k in [0, 1, 2, 10, 20]: + indices = x.argsort()[::-1][:k] + + def topk(v, k=k): + return nn_ops.top_k(v, k=k, sorted=True) + + self._assertOpOutputMatchesExpected( + topk, [x.astype(bfloat16)], + expected=[x[indices].astype(bfloat16), indices]) + + def testTopKZeros(self): + """Tests that positive and negative zeros sort correctly.""" + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 not in self.numeric_types: + return + + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.bfloat16) + with self.test_scope(): + topk = nn_ops.top_k(p, k=4) + results = sess.run( + topk, + {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) + self.assertAllEqual( + np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) + self.assertEqual(list([3, 0, 1, 2]), list(results[1])) + + def testTopKInfinities(self): + """Tests that positive and negative infinity sort correctly.""" + # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. + if self.device in ["XLA_CPU", "XLA_GPU"]: + return + + # Only bfloat16 is implemented. + bfloat16 = dtypes.bfloat16.as_numpy_dtype + if bfloat16 not in self.numeric_types: + return + + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.bfloat16) + with self.test_scope(): + topk = nn_ops.top_k(p, k=6) + results = sess.run(topk, { + p: np.array( + [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16) + }) + self.assertAllEqual( + np.array( + [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], + dtype=bfloat16), results[0]) + self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 689a4a1f4e02f5dd48f64dc94afd0fcb50df8b5b..e610b63e301c75f532db1b58cd26533effea174d 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -201,6 +201,16 @@ class UnaryOpsTest(XLATestCase): expected=np.array([1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype)) + # Disable float16 testing for now + if dtype != np.float16: + x = np.arange(-10, 10, 1).astype(dtype) + with self.test_session() as session: + erf_x = session.run(math_ops.erf(x)) + erfc_x = session.run(math_ops.erfc(x)) + + self._assertOpOutputMatchesExpected(math_ops.erf, x, expected=erf_x) + self._assertOpOutputMatchesExpected(math_ops.erfc, x, expected=erfc_x) + self._assertOpOutputMatchesExpected( math_ops.exp, np.array([[-1, 1]], dtype=dtype), diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index e924fe1e61454aefda622a5a46a0e483d26db5c1..88827cb53bee7bb809d0163d6badcef17e59aa78 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -49,6 +49,32 @@ flags.DEFINE_string('tf_xla_flags', None, 'Value to set the TF_XLA_FLAGS environment variable to') +def parse_disabled_manifest(manifest_content): + comments_re = re.compile('#.*$') + disabled_tests = [] + disabled_method_types = [] + for l in manifest_content.splitlines(): + stripped = comments_re.sub('', l).strip() + if not stripped: + continue + entry = stripped.split(' ') + if len(entry) == 1: + disabled_tests.append(entry[0]) + elif len(entry) == 2: + disabled_method_types.append((entry[0], entry[1].strip().split(','))) + else: + raise ValueError('Bad entry in manifest file.') + + disabled_regex = '|'.join(disabled_tests) + method_types_filter = dict() + for method, types in disabled_method_types: + method_types_filter[method] = set([ + dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype + for name in types + ]) + return disabled_regex, method_types_filter + + class XLATestCase(test.TestCase): """XLA test cases are parameterized test cases.""" @@ -85,38 +111,21 @@ class XLATestCase(test.TestCase): # Parse the manifest file, if any, into a regex identifying tests to # disable - self.disabled_regex = None - self._method_types_filter = dict() # TODO(xpan): Make it text proto if it doesn't scale. # Each line of the manifest file specifies an entry. The entry can be # 1) TestNameRegex // E.g. CumprodTest.* Or # 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16 # The 1) disables the entire test. While 2) only filter some numeric types # so that they are not used in those tests. + self.disabled_regex = None + self._method_types_filter = {} if FLAGS.disabled_manifest is not None: - comments_re = re.compile('#.*$') - manifest_file = open(FLAGS.disabled_manifest, 'r') - disabled_tests = [] - disabled_method_types = [] - for l in manifest_file.read().splitlines(): - if not l: - continue - entry = comments_re.sub('', l).strip().split(' ') - if len(entry) == 1: - disabled_tests.append(entry[0]) - elif len(entry) == 2: - disabled_method_types.append( - (entry[0], entry[1].strip().split(','))) - else: - raise ValueError('Bad entry in manifest file.') - - self.disabled_regex = re.compile('|'.join(disabled_tests)) - for method, types in disabled_method_types: - self._method_types_filter[method] = set([ - dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype - for name in types]) - manifest_file.close() + with open(FLAGS.disabled_manifest, 'r') as manifest_file: + disabled_regex, self._method_types_filter = ( + parse_disabled_manifest(manifest_file.read())) + if disabled_regex: + self.disabled_regex = re.compile(disabled_regex) if FLAGS.tf_xla_flags is not None: os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags diff --git a/tensorflow/compiler/tests/xla_test_test.py b/tensorflow/compiler/tests/xla_test_test.py new file mode 100644 index 0000000000000000000000000000000000000000..24664451579445edaadb335c30d253ee55f003da --- /dev/null +++ b/tensorflow/compiler/tests/xla_test_test.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the XLATestCase test fixture base class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.platform import test + + +class XlaTestCaseTestCase(test.TestCase): + + def testManifestEmptyLineDoesNotCatchAll(self): + manifest = """ +testCaseOne +""" + disabled_regex, _ = xla_test.parse_disabled_manifest(manifest) + self.assertEqual(disabled_regex, "testCaseOne") + + def testManifestWholeLineCommentDoesNotCatchAll(self): + manifest = """# I am a comment +testCaseOne +testCaseTwo +""" + disabled_regex, _ = xla_test.parse_disabled_manifest(manifest) + self.assertEqual(disabled_regex, "testCaseOne|testCaseTwo") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index cd57452302fcbde37d79ce760a80615a76d7ad8c..49c57a9f51369cba965d12fa12e96aced01ae6e8 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -406,12 +406,39 @@ cc_library( ], ) +cc_library( + name = "validate_control_flow", + srcs = ["validate_control_flow.cc"], + hdrs = ["validate_control_flow.h"], + deps = [ + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "validate_control_flow_test", + srcs = ["validate_control_flow_test.cc"], + deps = [ + ":validate_control_flow", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:while_loop", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "functionalize_control_flow", srcs = ["functionalize_control_flow.cc"], hdrs = ["functionalize_control_flow.h"], deps = [ ":tf2xla_util", + ":validate_control_flow", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", @@ -462,3 +489,13 @@ cc_library( "//tensorflow/core:protos_all_cc", ], ) + +tf_cc_test( + name = "xla_op_registry_test", + srcs = ["xla_op_registry_test.cc"], + deps = [ + ":xla_compiler", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 42585ad4d8a17d71146e48b69f9fa56f9ff24c3e..b9ed44e354ee1fa4cdb2aae7421116c636d07570 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/tf2xla/validate_control_flow.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" @@ -1438,7 +1439,15 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, // connected to all source nodes in the graph. Many graphs violate this // invariant. std::vector cf_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info)); + std::vector unreachable_nodes; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + BuildAndValidateControlFlowInfo(graph, &cf_info, &unreachable_nodes), + "FunctionalizeControlFlow failed"); + if (!unreachable_nodes.empty()) { + return errors::InvalidArgument( + "The following nodes are unreachable from the source in the graph: ", + tensorflow::str_util::Join(unreachable_nodes, ", ")); + } // Builds Frames, indexed by name. std::unordered_map frames; @@ -1458,10 +1467,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, frame.parent = parent; frame.name = cf.frame_name; ++parent->num_children; - } else if (frame.parent != parent) { - return errors::InvalidArgument("Mismatched parent frames for ", - cf.frame->id(), ": ", parent->name, " vs ", - frame.parent->name); } if (IsEnter(node)) { @@ -1471,12 +1476,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, &arg.is_loop_invariant)); frame.args.push_back(arg); } else if (IsLoopCond(node)) { - if (frame.loop_cond) { - return errors::InvalidArgument( - "Loop ", cf.frame_name, - " has more than one LoopCond node: ", node->name(), " and ", - frame.loop_cond->name()); - } frame.loop_cond = node; } frame.nodes.insert(node); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index edd2ab6301ee891c433639ce300cde0c72929cea..c431a4b9cf5651aaa9baeb8ede12059ead6e2e39 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -79,6 +79,7 @@ tf_kernel_library( "shape_util.cc", "slice_op.cc", "softmax_op.cc", + "sort_ops.cc", "spacetobatch_op.cc", "spacetodepth_op.cc", "split_op.cc", @@ -87,6 +88,7 @@ tf_kernel_library( "strided_slice_op.cc", "tensor_array_ops.cc", "tile_ops.cc", + "topk_op.cc", "training_ops.cc", "transpose_op.cc", "unary_ops.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 8b9b026643cf35216a2082dfcce9270c017bd14f..d48c6eea754f75a8879d3938f233a6a591d26d0d 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -48,11 +48,11 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; - std::vector inputs(input_types_.size()); std::vector arguments(input_types_.size()); for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; DataType type = ctx->input_type(i + 1); + if (type == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); @@ -60,7 +60,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.initialized = resource->initialized(); arg.kind = XlaCompiler::Argument::kResource; arg.resource_kind = resource->kind(); - OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); arg.type = resource->type(); arg.shape = resource->shape(); @@ -79,7 +78,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; arg.shape = ctx->InputShape(i + 1); - inputs[i] = ctx->Input(i + 1); VLOG(2) << "Arg type: " << DataTypeString(arg.type) << " shape: " << arg.shape.DebugString(); } @@ -100,6 +98,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, arguments, &else_result)); + bool has_tensor_array_gradients = false; for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { XlaResource* resource; @@ -121,9 +120,21 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } + if (!resource->tensor_array_gradients().empty()) + has_tensor_array_gradients = true; } } + // Recompile the functions to update the argument shapes for tensor arrays. + if (has_tensor_array_gradients) { + then_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + else_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + } + // Check that both branches have identical input shapes. OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); @@ -175,6 +186,19 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { "Mismatch in resource of then and else branch for resource ", i)); } + int num_inputs = then_result.input_mapping.size(); + std::vector inputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + int input_num = then_result.input_mapping[i] + 1; + if (ctx->input_type(input_num) == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + } else { + inputs[i] = ctx->Input(i + 1); + } + } + xla::XlaOp outputs = b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, b->Tuple(inputs), *else_result.computation); diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 7e9de3ef9b245c113cc143128fe58e7e017a361c..c3326b4d11432fb17a02e9a336a70d88bf40da6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -27,7 +27,7 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr DoMirrorPad(const xla::XlaOp& t, const xla::Shape& original_shape, - const xla::Literal& pad_literal, + const xla::LiteralSlice& pad_literal, xla::XlaBuilder* b) { xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 7c95475e7b1f02183e44f73f116a4aeb25f05c09..17b85338f75d6295c6b4a1bf1db24aa641eab020 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -63,8 +63,8 @@ class PadOp : public XlaOpKernel { int before = pad_literal.Get({i, 0}); int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, - errors::InvalidArgument("Paddings must be non-negative: ", - before, " ", after)); + errors::InvalidArgument( + "Paddings must be non-negative: ", before, " ", after)); dim->set_edge_padding_low(before); dim->set_edge_padding_high(after); } diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 39149d56adb24404e1788a634913615298ee5a33..aa4d242a11b2fb28b4f1949ec5249ae6c8a1b101 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,6 +17,10 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. +#include + +#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -56,6 +60,78 @@ class RandomUniformOp : public XlaOpKernel { REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), RandomUniformOp); +class RandomShuffleOp : public XlaOpKernel { + public: + explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + const int64 n = input_shape.dim_size(0); + int64 num_elements = 1; + for (tensorflow::TensorShapeDim dimension : input_shape) { + num_elements *= dimension.size; + } + if (num_elements <= 1 || n <= 1) { + // No shuffling is required, so copy input directly to output + ctx->SetOutput(0, input); + } else { + // Generate the random swaps for the indices. + auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); + auto swaps = + builder->RngUniform(builder->ConstantR0(0), + builder->ConstantR0(n), swaps_shape); + + // Generate range(n) as the initial value for the indices to be swapped. + xla::XlaOp indices; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); + + // Swap the indices at i and swaps[i]. + auto swap_body_fn = [&](xla::XlaOp i, + gtl::ArraySlice loop_vars, + xla::XlaBuilder* builder) + -> xla::StatusOr> { + auto swaps = loop_vars[0]; + auto indices = loop_vars[1]; + i = builder->Reshape(i, {1}); + // temp = indices[i] + auto temp = builder->DynamicSlice(indices, i, {1}); + // swap_index = swaps[i] + auto swap_index = builder->DynamicSlice(swaps, i, {1}); + // swap_value = indices[swaps[i]] + auto swap_value = builder->DynamicSlice(indices, swap_index, {1}); + // indices[i] = indices[swaps[i]] + indices = builder->DynamicUpdateSlice(indices, swap_value, i); + // indices[swaps[i]] = temp + indices = builder->DynamicUpdateSlice(indices, temp, swap_index); + return std::vector{swaps, indices}; + }; + // for i in range(n): + auto swap_loop_result = + XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) + .ValueOrDie(); + auto swapped_indices = swap_loop_result[1]; + + // Gather the data using the swapped indices as the shuffled order. + auto indices_tensor_shape = TensorShape({n}); + DataType type = ctx->expected_output_dtype(0); + xla::XlaOp gather; + OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, + indices_tensor_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, + DT_INT32, builder, &gather)); + ctx->SetOutput(0, gather); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp); +}; + +REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp); + class RandomUniformIntOp : public XlaOpKernel { public: explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -131,58 +207,44 @@ class TruncatedNormalOp : public XlaOpKernel { xla::XlaBuilder* b = ctx->builder(); - auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) { - return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0); - }; - auto out_of_range_mask = [two_sd](xla::XlaOp candidate, - xla::XlaBuilder* b) { - xla::XlaOp too_large = b->Gt(candidate, two_sd(false, b)); - xla::XlaOp too_small = b->Lt(candidate, two_sd(true, b)); - return b->Or(too_large, too_small); + auto normal_cdf = [](double x) { + return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0; }; - // The algorithm we're using is roughly: - // - // while (any(candidate < mean-2*sd || candidate > mean+2*sd)) { - // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd - // candidate = select(out_of_range_mask, rng_normal(), candidate) - // } - std::vector initial_values = { - // The current candidate. - b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()), - // The to_resample mask, where 'true' identifies a location in the - // current candidate that is out of range and must be regenerated. - b->Broadcast(b->ConstantR0(true), shape.dim_sizes()), - // Is any element in the mask true? - b->ConstantR0(true)}; - auto condition = [&](gtl::ArraySlice values, - xla::XlaBuilder* b) -> xla::StatusOr { - // Continue while any element in the mask is true. - return values[2]; - }; - auto body = - [&](gtl::ArraySlice values, - xla::XlaBuilder* b) -> xla::StatusOr> { - xla::XlaOp candidate = values[0]; - xla::XlaOp to_resample = values[1]; - xla::XlaOp mean = XlaHelpers::Zero(b, dtype); - xla::XlaOp stddev = XlaHelpers::One(b, dtype); - candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), - candidate); - // Compute a new to_resample mask, and determine whether any value is - // still out of range. - to_resample = out_of_range_mask(candidate, b); - TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b)); - return std::vector{candidate, to_resample, done}; - }; - auto result = - XlaWhileLoop(condition, body, initial_values, "truncated_normal", b); - OP_REQUIRES_OK(ctx, result.status()); - ctx->SetOutput(0, result.ValueOrDie()[0]); + const double kA = -2.0; + const double kB = 2.0; + const double kMu = 0.0; + const double kSigma = 1.0; + const double kAlpha = (kA - kMu) / kSigma; + const double kBeta = (kB - kMu) / kSigma; + const double kAlphaNormalCdf = normal_cdf(kAlpha); + const double kBetaNormalCdf = normal_cdf(kBeta); + const double kZ = kBetaNormalCdf - kAlphaNormalCdf; + + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); + xla::XlaOp sqrt_2 = XlaHelpers::FloatLiteral(b, dtype, std::sqrt(2.0)); + xla::XlaOp min_positive = + XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits::min()); + + xla::XlaOp z = XlaHelpers::FloatLiteral(b, dtype, kZ); + xla::XlaOp alpha_normal_cdf = + XlaHelpers::FloatLiteral(b, dtype, kAlphaNormalCdf); + + auto uniform = b->RngUniform(min_positive, one, xla_shape); + // probit(p) = sqrt(2) * erfinv(2*p-1) + auto p = b->Add(alpha_normal_cdf, b->Mul(z, uniform)); + auto erfinv_input = b->Sub(b->Mul(p, two), one); + auto erfinv_or_status = ErfInv(b, erfinv_input); + OP_REQUIRES_OK(ctx, erfinv_or_status.status()); + auto probit = b->Mul(sqrt_2, erfinv_or_status.ValueOrDie()); + ctx->SetOutput(0, probit); } }; -REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"), +REGISTER_XLA_OP(Name("TruncatedNormal") + .CompileTimeConstInput("shape") + .TypeConstraint("dtype", DT_FLOAT), TruncatedNormalOp); } // anonymous namespace diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 4fd5bfd03999a7f8b7bb081cc4b03aa1434d4c3d..44510c731e0dcb94c7f864053354b7a6a42d93f9 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -56,9 +56,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { // Evaluate the constant, reshaping to a 1-vector if it is a scalar. xla::Literal axes_literal; - OP_REQUIRES_OK(ctx, - ctx->ConstantInputReshaped( - 1, {axes_tensor_shape.num_elements()}, &axes_literal)); + OP_REQUIRES_OK( + ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()}, + &axes_literal)); VLOG(1) << "data shape: " << data_shape.DebugString(); VLOG(1) << "axes : " << axes_literal.ToString(); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 2c31f8d90891924f6f86a54ccf548de4df87f3bd..bc3d0bf5dfe9e5af8e50a25e27db7148e05e0cfd 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -55,9 +55,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { // The type-specific part of the implementation of Range. template -Status CreateRangeTensor(const xla::Literal& start_literal, - const xla::Literal& limit_literal, - const xla::Literal& delta_literal, Tensor* output) { +Status CreateRangeTensor(const xla::LiteralSlice& start_literal, + const xla::LiteralSlice& limit_literal, + const xla::LiteralSlice& delta_literal, + Tensor* output) { T start = start_literal.Get({}); T limit = limit_literal.Get({}); T delta = delta_literal.Get({}); @@ -67,13 +68,13 @@ Status CreateRangeTensor(const xla::Literal& start_literal, } if (delta > 0) { if (start > limit) { - return errors::InvalidArgument("Requires start <= limit when delta > 0: ", - start, "/", limit); + return errors::InvalidArgument( + "Requires start <= limit when delta > 0: ", start, "/", limit); } } else { if (start < limit) { - return errors::InvalidArgument("Requires start >= limit when delta < 0: ", - start, "/", limit); + return errors::InvalidArgument( + "Requires start >= limit when delta < 0: ", start, "/", limit); } } int64 size = diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 05354bca5bb089703fdcceb6f44648bbb98d004b..d59720bef742c7441ee01a954247013559bb909c 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -43,7 +43,7 @@ class ShapeOp : public XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape"), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -65,7 +65,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -81,7 +81,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank"), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -100,7 +100,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size"), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: @@ -189,10 +189,9 @@ class SqueezeOp : public XlaOpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument("Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", - existing_dim)); + errors::InvalidArgument( + "Tried to explicitly squeeze dimension ", i, + " but dimension was not 1: ", existing_dim)); } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..204ae8458214a0d0f049cff32ea99540b6f7fbd6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" + +namespace tensorflow { +namespace { + +class XlaSortOp : public XlaOpKernel { + public: + explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* const b = context->builder(); + context->SetOutput(0, b->Sort(context->Input(0))); + } +}; + +REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 8958b2e7701e62d802e37a895c14b662ecf9786a..9b540585416ded663467d32f25ceceeaa52f069a 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -134,7 +134,7 @@ class SplitVOp : public XlaOpKernel { errors::InvalidArgument( "Number of ways to split should be > 0, but got ", num_split)); - // check that sizes are correct + // Check that sizes are correct. int total_split_size = 0; int neg_one_dim = -1; std::vector split_sizes_vec(num_split, -1); @@ -148,7 +148,7 @@ class SplitVOp : public XlaOpKernel { " number of elements as the output. Got ", split_size_shape.dims(), "-D and ", split_size_shape.num_elements(), " elements")); - // get the dimension of this split + // Get the dimension of this split. xla::Literal split_size_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal)); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index a99d4ddc7c4956f7144512a9bdf6f4c2eb0f944f..58c5dc5aa9c161c9765ada4ee6eeebd72fe08eef 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -163,51 +163,6 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, return floats; } -// Approximation for the inverse error function from -// Giles, M., "Approximating the erfinv function". -// The approximation has the form: -// w = -log((1 - x) * (1 + x)) -// if ( w < 5 ) { -// w = w - 2.5 -// p = sum_{i=1}^n lq[i]*w^i -// } else { -// w = sqrt(w) - 3 -// p = sum_{i=1}^n gq[i]*w^i -// } -// return p*x -xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x, - const TensorShape& shape) { - constexpr int kDegree = 9; - constexpr std::array w_less_than_5_constants = { - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - constexpr std::array w_greater_than_5_constants = { - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - - auto one = b->ConstantR0(1.0); - auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); - - auto lt = b->Lt(w, b->ConstantR0(5.0)); - auto coefficient = [&](int i) { - return b->Select( - lt, - b->Broadcast(b->ConstantR0(w_less_than_5_constants[i]), - shape.dim_sizes()), - b->Broadcast(b->ConstantR0(w_greater_than_5_constants[i]), - shape.dim_sizes())); - }; - w = b->Select(lt, b->Sub(w, b->ConstantR0(2.5f)), - b->Sub(b->SqrtF32(w), b->ConstantR0(3.0f))); - auto p = coefficient(0); - for (int i = 1; i < kDegree; ++i) { - p = b->Add(coefficient(i), b->Mul(p, w)); - } - return b->Mul(p, x); -} - } // namespace class StatelessRandomUniformOp : public XlaOpKernel { @@ -259,8 +214,10 @@ class StatelessRandomNormalOp : public XlaOpKernel { RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) + auto erfinv_or_status = ErfInv(builder, uniform); + OP_REQUIRES_OK(ctx, erfinv_or_status.status()); auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), - ErfInvF32(builder, uniform, shape)); + erfinv_or_status.ValueOrDie()); ctx->SetOutput(0, normal); } diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..cbe3c8aaff02e1a4b19f295216772b2004ccaf70 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/no_op.h" + +namespace tensorflow { +namespace { + +class TopKOp : public XlaOpKernel { + public: + explicit TopKOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_)); + } + + void Compile(XlaOpKernelContext* context) override { + int64 k; + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(1, &k)); + OP_REQUIRES(context, k >= 0, + errors::InvalidArgument("Need k >= 0, got ", k)); + const TensorShape input_shape = context->InputShape(0); + OP_REQUIRES(context, input_shape.dims() >= 1, + errors::InvalidArgument("input must be >= 1-D, got shape ", + input_shape.DebugString())); + OP_REQUIRES( + context, input_shape.dim_size(input_shape.dims() - 1) >= k, + errors::InvalidArgument("input must have at least k columns. Had ", + input_shape.dim_size(input_shape.dims() - 1), + ", needed ", k)); + + OP_REQUIRES( + context, input_shape.dims() == 1, + errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ", + input_shape.DebugString())); + + const int64 n = input_shape.dim_size(0); + OP_REQUIRES(context, n < (1 << 16), + errors::Unimplemented( + "TopK is implemented for sizes up to 2**16, got shape ", + input_shape.DebugString())); + + xla::XlaBuilder* const b = context->builder(); + if (input_shape.dim_size(0) < k) { + k = input_shape.dim_size(0); + } + const xla::XlaOp input_bf16 = context->Input(0); + xla::XlaOp iota_s32; + OP_REQUIRES_OK(context, XlaHelpers::Iota(b, DT_INT32, n, &iota_s32)); + + // TODO(b/73891930): add a key-value sort to HLO, rather than using + // bit-packing tricks here. + + xla::XlaOp zero = b->ConstantR0(0); + + // max can either be 0x7FFFFFFF or 0x8000000. Neither choice is totally + // ideal. The implications of the choice are: + // + // 0x7FFFFFFF + // 1. +0.0 > -0.0 + // 2. The elements of the inputs and outputs are bitwise identical. + // 3. The sort is unstable since a later +0.0 will appear before an earlier + // -0.0. + // + // 0x8000000 + // 1. +0.0 == -0.0 + // 2. All -0.0 in the input are replaced with +0.0 in the output. + // 3. The sort is stable. + xla::XlaOp max = b->ConstantR0(0x80000000); + xla::XlaOp index_mask = b->ConstantR0(0x0000FFFF); + xla::XlaOp value_mask = b->ConstantR0(0xFFFF0000); + + // Convert to from bf16 to f32. The lower 16-bits are zero due to the + // definition of bf16. + xla::XlaOp input_f32 = b->ConvertElementType(input_bf16, xla::F32); + + // Negate the input to reverse sort it. The lower 16-bits are zero, because + // negating a float is just inverting the high-bit. + xla::XlaOp negative_input_f32 = b->Neg(input_f32); + + // Convert to a sign magnitude integer. The lower 16-bits are zero, since + // bitcast convert doesn't change any bits. + xla::XlaOp negative_input_sm32 = + b->BitcastConvertType(negative_input_f32, xla::S32); + + // Convert from sign magnitude integer to two's complement integer. The + // lower 16-bits are zero on both sides of the select. On the false side, + // the value is unchanged, and on the true side, the lower 16-bits of max + // are all zero, so the lower 16-bits of the result of the subtraction will + // also be zero. + xla::XlaOp negative_input_s32 = + b->Select(b->Lt(negative_input_sm32, zero), + b->Sub(max, negative_input_sm32), negative_input_sm32); + + // In order for the Or with iota_s32 to to work properly, the lower 16-bits + // of negative_input_32 must be zero. + + // Pack elements as: + // * upper 16 bits are the value + // * lower 16 bits are the index. + xla::XlaOp packed_s32 = b->Or(negative_input_s32, iota_s32); + + // TODO(phawkins): use a more efficient algorithm that does not require a + // full sort. + xla::XlaOp sorted_s32 = b->Slice(b->Sort(packed_s32), + /*start_indices=*/{0}, + /*limit_indices=*/{k}, + /*strides=*/{1}); + + // Unpack the value/index. + xla::XlaOp indices_s32 = b->And(sorted_s32, index_mask); + xla::XlaOp negative_values_s32 = b->And(sorted_s32, value_mask); + + // Convert from two's complement integer to sign magnitude integer. + xla::XlaOp negative_values_sm32 = + b->Select(b->Lt(negative_values_s32, zero), + b->Sub(max, negative_values_s32), negative_values_s32); + + xla::XlaOp negative_values_f32 = + b->BitcastConvertType(negative_values_sm32, xla::F32); + + // Negate the values to get back the original inputs. + xla::XlaOp values_f32 = b->Neg(negative_values_f32); + + // Convert from f32 to bf16. + xla::XlaOp values_bf16 = b->ConvertElementType(values_f32, xla::BF16); + + context->SetOutput(0, values_bf16); + context->SetOutput(1, indices_s32); + } + + private: + bool sorted_; +}; + +REGISTER_XLA_OP( + Name("TopKV2").CompileTimeConstInput("k").TypeConstraint("T", DT_BFLOAT16), + TopKOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index c167642174b328a968d7f7ce1f0ad6e0ab8a7a68..ef5aae81a8d73ba326d4116d48b9eebee3c44098 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -32,7 +32,8 @@ namespace { class TransposeOp : public XlaOpKernel { public: - explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit TransposeOp(OpKernelConstruction* ctx, bool conjugate = false) + : XlaOpKernel(ctx), conjugate_(conjugate) {} void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); @@ -78,19 +79,37 @@ class TransposeOp : public XlaOpKernel { errors::InvalidArgument(i, " is missing from 'perm' argument.")); } + xla::XlaOp transposed; // 0-D, 1-D, and identity transposes do nothing. if (dims <= 1 || is_identity) { - ctx->SetOutput(0, ctx->Input(0)); - return; + transposed = ctx->Input(0); + } else { + transposed = ctx->builder()->Transpose(ctx->Input(0), transposed_order); } - ctx->SetOutput(0, - ctx->builder()->Transpose(ctx->Input(0), transposed_order)); + // Conjugate the transposed result if this is ConjugateTransposeOp. + if (conjugate_) { + ctx->SetOutput(0, ctx->builder()->Conj(transposed)); + } else { + ctx->SetOutput(0, transposed); + } } + + private: + const bool conjugate_; +}; + +class ConjugateTransposeOp : public TransposeOp { + public: + explicit ConjugateTransposeOp(OpKernelConstruction* ctx) + : TransposeOp(ctx, /*conjugate=*/true) {} }; REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); +REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstInput("perm"), + ConjugateTransposeOp); + // InvertPermutation frequently forms part of the gradient of Transpose. // // inv = InvertPermutationOp(T p) takes a permutation of diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 71a9fd051bfc8db09738a4bfe8ddde447895ecf0..1d078de2114fece4dd8f894a39635245e98dd567 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -16,9 +16,11 @@ limitations under the License. // Native XLA implementations of simple unary Ops #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -185,5 +187,49 @@ XLAJIT_MAKE_UNARY(Imag, b->Imag(x)); #undef XLAJIT_MAKE_UNARY +// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial +// is used outside of this range. +class ErfOp : public XlaOpKernel { + public: + explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::PrimitiveType primitive_type; + xla::XlaOp one = XlaHelpers::One(b, input_type(0)); + xla::XlaOp x = ctx->Input(0); + xla::XlaOp abs_x = b->Abs(x); + + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(input_type(0), &primitive_type)); + + auto y = + b->Select(b->Gt(abs_x, one), b->Sub(one, Erfc(b, x, primitive_type)), + Erf(b, x, primitive_type)); + ctx->SetOutput(0, y); + } +}; +REGISTER_XLA_OP(Name("Erf"), ErfOp); + +class ErfcOp : public XlaOpKernel { + public: + explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp one = XlaHelpers::One(b, input_type(0)); + xla::XlaOp x = ctx->Input(0); + xla::XlaOp abs_x = b->Abs(x); + + xla::PrimitiveType primitive_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(input_type(0), &primitive_type)); + + auto y = + b->Select(b->Lt(abs_x, one), b->Sub(one, Erf(b, x, primitive_type)), + Erfc(b, x, primitive_type)); + ctx->SetOutput(0, y); + } +}; +REGISTER_XLA_OP(Name("Erfc"), ErfcOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 526694d5a0c7124e1696f34b516f3b202462bc19..ee0bb91a6b747ffc9e28e19dd4869a5b2cc43501 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -71,8 +71,8 @@ xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, } // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::HasZeroElements(x_shape) || - xla::ShapeUtil::HasZeroElements(y_shape)) { + if (xla::ShapeUtil::IsZeroElementArray(x_shape) || + xla::ShapeUtil::IsZeroElementArray(y_shape)) { std::vector dimensions(batch_dimension_numbers.size()); for (int i = 0; i < batch_dimension_numbers.size(); ++i) { dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 3f1384bc864abd882ebba2b90acbe0b1e664687a..20925118bf598a6436c43bd727ce40e3abafc46c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -110,7 +110,6 @@ xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, FloatLiteral(body_builder, a_shape.element_type(), 0.5)); // a[..., i+1:, i] - auto ip1 = body_builder->Add(i, body_builder->ConstantR0(1)); // select the whole i-th column, then mask out all rows above i+1 TF_ASSIGN_OR_RETURN( auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 43e1c1e9fecec1c71db1509757251cb5d903ca49..b43405a1a407b5fa98dd740c62af91e048cc9490 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -22,21 +22,34 @@ limitations under the License. namespace tensorflow { -Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { - xla::Shape literal_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape( - host_tensor.dtype(), host_tensor.shape(), &literal_shape)); +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + *literal = xla::BorrowingLiteral( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); + return Status::OK(); +} - *literal = xla::Literal(literal_shape); +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal) { + std::vector buf_ptrs; + buf_ptrs.reserve(host_tensors.size()); + std::vector tensor_shapes(host_tensors.size()); - // memcpy over the payload ... - // TODO(phawkins): handle string types. - size_t total_bytes = host_tensor.TotalBytes(); - if (total_bytes > 0) { - void* dst_ptr = literal->untyped_data(); - const void* src_ptr = DMAHelper::base(&host_tensor); - memcpy(dst_ptr, src_ptr, total_bytes); + for (int i = 0; i < host_tensors.size(); i++) { + // Validate runtime shapes and fail if it doesn't match the contract. + const Tensor* tensor = &host_tensors[i]; + buf_ptrs.emplace_back(static_cast(DMAHelper::base(tensor))); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(), + &tensor_shapes[i])); } + + *literal = xla::BorrowingLiteral( + buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes)); + return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 220bec15538c36fa30abef9e729b64dbbb9f72b3..ab7e861f3336097d2ea52487092f16edb5c14531 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -22,12 +22,20 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { -// Copies 'host_tensor' to an XLA Literal. Fails if host_tensor is of an -// unsupported type. -Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); +// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by +// 'host_tensor'. +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal); + +// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers +// owned by 'host_tensors'. +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal); // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index bb9168fa358154f3db9dab87bacc9bf28dd16406..ace6fd1d8eeaf439509a7b75d8d986997c392e73 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -8,12 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") cc_library( name = "xla_ops", - srcs = [ - "dynamic_slice_ops.cc", - "functional_ops.cc", - "reduce_window_op.cc", - "sendrecv_ops.cc", - ], + srcs = ["xla_ops.cc"], deps = [ "//tensorflow/core:framework", ], diff --git a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc deleted file mode 100644 index d6c0edbb889b1751ac9d9d47d0c9534b543196ff..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -REGISTER_OP("XlaDynamicUpdateSlice") - .Input("input: T") - .Input("update: T") - .Input("indices: Tindices") - .Output("output: T") - .Attr("T: type") - .Attr("Tindices: {int32, int64}") - .SetShapeFn(shape_inference::UnchangedShape) - .Doc(R"doc( -Wraps the XLA DynamicUpdateSlice operator, documented at - https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice -. - -XlaDynamicUpdateSlice generates a result which is the value of the `input` -operand, with a slice update overwritten at `indices`. The shape of `update` -determines the shape of the sub-array of the result which is updated. The shape -of indices must be rank == 1, with dimension size equal to the rank of `input`. - -Handling of out-of-bounds slice indices is implementation-defined. - -input: A `Tensor` of type T. -indices: A vector of indices into `input`. Must have length equal to the rank of - `input`. -update: A `Tensor` of type T. Same rank as `input`. -output: A `Tensor` of type T. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc deleted file mode 100644 index 4a669f8e6eaf644f119f3c0a66f29d9f2c9a9d16..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -// TODO(b/37549631) setting the While Op to always be stateful is too -// conservative. -REGISTER_OP("XlaWhile") - .Input("input: T") - .Output("output: T") - .Attr("T: list(type) >= 0") - .Attr("cond: func") - .Attr("body: func") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -output = input; While (Cond(output)) { output = Body(output) } - -input: A list of input tensors whose types are T. -output: A list of output tensors whose types are T. -cond: A function takes 'input' and returns a tensor. If the tensor is - a scalar of non-boolean, the scalar is converted to a boolean - according to the following rule: if the scalar is a numerical - value, non-zero means True and zero means False; if the scalar is - a string, non-empty means True and empty means False. If the - tensor is not a scalar, non-emptiness means True and False - otherwise. -body: A function that takes a list of tensors and returns another - list of tensors. Both lists have the same types as specified by T. -)doc"); - -// TODO(b/37549631) setting the If Op to always be stateful is too -// conservative. -REGISTER_OP("XlaIf") - .Input("cond: Tcond") - .Input("inputs: Tin") - .Output("output: Tout") - .Attr("Tcond: type") - .Attr("then_branch: func") - .Attr("else_branch: func") - .Attr("Tin: list(type) >= 0") - .Attr("Tout: list(type) >= 0") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -output = cond ? then_branch(inputs) : else_branch(inputs). - -cond: A boolean scalar. -inputs: A list of input tensors. -output: A list of tensors returned by either then_branch(inputs) or - else_branch(inputs). The input shapes of the then_branch and - else_branch must match. -then_branch: A function takes 'inputs' and returns a list of tensors, - whose types are the same as what else_branch returns. -else_branch: A function takes 'inputs' and returns a list of tensors. - whose types are the same as what then_branch returns. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc b/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc deleted file mode 100644 index d9af982adc090ea78c711fd4656ba429c53b18c9..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("XlaReduceWindow") - .Input("input: T") - .Input("init_value: T") - .Attr("T: numbertype") - .Attr("computation: func") - .Attr("window_dimensions: list(int)") - .Attr("window_strides: list(int)") - .Attr("padding_low: list(int)") - .Attr("padding_high: list(int)") - .Output("output: T") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Wraps the XLA ReduceWindow operator, documented at - https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . - -input: the input tensor -init_value: a scalar representing the initial value for the reduction -computation: a reducer function to apply -window_dimensions: the shape of the window -window_strides: the inter-window strides -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc deleted file mode 100644 index 7ec7b50e905a6cbdecea4543dcb87322b5a7e844..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("XlaSend") - .Input("tensor: T") - .Attr("T: type") - .Attr("tensor_name: string") - .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Sends the named tensor to another XLA computation. Wraps the XLA Send operator -documented at - https://www.tensorflow.org/performance/xla/operation_semantics#send . - -tensor: The tensor to send. -tensor_name: A string key that identifies the channel. -)doc"); - -REGISTER_OP("XlaRecv") - .Output("tensor: dtype") - .Attr("dtype: type") - .Attr("tensor_name: string") - .Attr("shape: shape") - .SetIsStateful() - .SetShapeFn([](shape_inference::InferenceContext* c) { - TensorShape shape_attr; - TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); - shape_inference::ShapeHandle s; - TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); - c->set_output(0, s); - return Status::OK(); - }) - .Doc(R"doc( -Receives the named tensor from another XLA computation. Wraps the XLA Recv -operator documented at - https://www.tensorflow.org/performance/xla/operation_semantics#recv . - -tensor: The tensor to receive. -dtype: The type of the tensor. -tensor_name: A string key that identifies the channel. -shape: The shape of the tensor. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..a59c77f5c3a309abe8f6fbab1e48455d54e8fae5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("XlaDynamicUpdateSlice") + .Input("input: T") + .Input("update: T") + .Input("indices: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA DynamicUpdateSlice operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice +. + +XlaDynamicUpdateSlice generates a result which is the value of the `input` +operand, with a slice update overwritten at `indices`. The shape of `update` +determines the shape of the sub-array of the result which is updated. The shape +of indices must be rank == 1, with dimension size equal to the rank of `input`. + +Handling of out-of-bounds slice indices is implementation-defined. + +input: A `Tensor` of type T. +indices: A vector of indices into `input`. Must have length equal to the rank of + `input`. +update: A `Tensor` of type T. Same rank as `input`. +output: A `Tensor` of type T. +)doc"); + +// TODO(b/37549631) setting the If Op to always be stateful is too +// conservative. +REGISTER_OP("XlaIf") + .Input("cond: Tcond") + .Input("inputs: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = cond ? then_branch(inputs) : else_branch(inputs). + +cond: A boolean scalar. +inputs: A list of input tensors. +output: A list of tensors returned by either then_branch(inputs) or + else_branch(inputs). The input shapes of the then_branch and + else_branch must match. +then_branch: A function takes 'inputs' and returns a list of tensors, + whose types are the same as what else_branch returns. +else_branch: A function takes 'inputs' and returns a list of tensors. + whose types are the same as what then_branch returns. +)doc"); + +REGISTER_OP("XlaRecv") + .Output("tensor: dtype") + .Attr("dtype: type") + .Attr("tensor_name: string") + .Attr("shape: shape") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* c) { + TensorShape shape_attr; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +Receives the named tensor from another XLA computation. Wraps the XLA Recv +operator documented at + https://www.tensorflow.org/performance/xla/operation_semantics#recv . + +tensor: The tensor to receive. +dtype: The type of the tensor. +tensor_name: A string key that identifies the channel. +shape: The shape of the tensor. +)doc"); + +REGISTER_OP("XlaReduceWindow") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("computation: func") + .Attr("window_dimensions: list(int)") + .Attr("window_strides: list(int)") + .Attr("padding_low: list(int)") + .Attr("padding_high: list(int)") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA ReduceWindow operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + +input: the input tensor +init_value: a scalar representing the initial value for the reduction +computation: a reducer function to apply +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding_low: the padding to apply at the start of each input dimensions +padding_high: the padding to apply at the end of each input dimension. +)doc"); + +REGISTER_OP("XlaSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Sends the named tensor to another XLA computation. Wraps the XLA Send operator +documented at + https://www.tensorflow.org/performance/xla/operation_semantics#send . + +tensor: The tensor to send. +tensor_name: A string key that identifies the channel. +)doc"); + +REGISTER_OP("XlaSort") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. + +input: A `Tensor` of type T. +output: A `Tensor` of type T. +)doc"); + +// TODO(b/37549631) setting the While Op to always be stateful is too +// conservative. +REGISTER_OP("XlaWhile") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = input; While (Cond(output)) { output = Body(output) } + +input: A list of input tensors whose types are T. +output: A list of output tensors whose types are T. +cond: A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +body: A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified by T. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index e5ce65bec950fdfd38c3ca5bc62ac745ef8ca4a7..2fc47dffb8f5f16f24e3beb1ff75aeed3e857c58 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -77,4 +77,6 @@ def reduce_window(operand, recv = gen_xla_ops.xla_recv send = gen_xla_ops.xla_send +sort = gen_xla_ops.xla_sort + while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/validate_control_flow.cc b/tensorflow/compiler/tf2xla/validate_control_flow.cc new file mode 100644 index 0000000000000000000000000000000000000000..1b3be4cfa4aff2e6eb551eaf7dc8bbce5c433f9e --- /dev/null +++ b/tensorflow/compiler/tf2xla/validate_control_flow.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/validate_control_flow.h" + +#include + +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { +// Information about a loop frame structure. +struct Frame { + string name; + + // Pointer to the parent frame. The root frame has a pointer to itself. + Frame* parent = nullptr; + + // The loop condition of the loop. There should be exactly one loop condition + // in every loop. + const Node* loop_cond = nullptr; +}; + +// Verify that the ControlFlowInfo of the graph has valid loop structure. +Status ValidateControlFlowInfo(const Graph* graph, + const std::vector& cf_info) { + std::unordered_map frames; + for (const Node* node : graph->op_nodes()) { + const ControlFlowInfo& cf = cf_info[node->id()]; + if (!cf.frame || !cf.parent_frame) { + // Skip nodes unreachable from the source node. They might be pruned + // later. + continue; + } + + Frame& frame = frames[cf.frame_name]; + Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; + if (frame.parent == nullptr) { + frame.parent = parent; + frame.name = cf.frame_name; + } else if (frame.parent != parent) { + return errors::InvalidArgument( + "Invalid loop structure: Mismatched parent frames for \"", + cf.frame_name, "\": \"", parent->name, "\" vs \"", frame.parent->name, + "\". This is an internal bug, please file a bug report with " + "instructions on how to reproduce the error."); + } + if (IsLoopCond(node)) { + if (frame.loop_cond) { + return errors::InvalidArgument( + "Invalid loop structure: Loop \"", cf.frame_name, + "\" has more than one LoopCond node: \"", node->name(), "\" and \"", + frame.loop_cond->name(), + "\". This is an internal bug, please file a bug report with " + "instructions on how to reproduce the error."); + } + frame.loop_cond = node; + } + } + return Status::OK(); +} +} // namespace + +Status BuildAndValidateControlFlowInfo(const Graph* graph, + std::vector* info, + std::vector* unreachable_nodes) { + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, info, unreachable_nodes)); + return ValidateControlFlowInfo(graph, *info); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/validate_control_flow.h b/tensorflow/compiler/tf2xla/validate_control_flow.h new file mode 100644 index 0000000000000000000000000000000000000000..74159dc9291bf57612553b5491ea90ed4d681556 --- /dev/null +++ b/tensorflow/compiler/tf2xla/validate_control_flow.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_VALIDATE_CONTROL_FLOW_H_ +#define TENSORFLOW_COMPILER_TF2XLA_VALIDATE_CONTROL_FLOW_H_ + +#include + +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Populate the control flow frame info of each node in the graph. Verify that +// the graph has well-formed control flow strcuture that can be functionalized. +// If unreachable_nodes is not nullptr, append to it the names of nodes +// unreachable from the source node. +Status BuildAndValidateControlFlowInfo( + const Graph* graph, std::vector* info, + std::vector* unreachable_nodes = nullptr); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_VALIDATE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/validate_control_flow_test.cc b/tensorflow/compiler/tf2xla/validate_control_flow_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..74c9f4b86cae440ee17134ab7dd6af9fcc97134e --- /dev/null +++ b/tensorflow/compiler/tf2xla/validate_control_flow_test.cc @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/validate_control_flow.h" + +#include +#include + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/cc/ops/while_loop.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +Status LessThanTenCond(const Scope& scope, const std::vector& inputs, + Output* output) { + *output = ops::Less(scope, inputs[0], 10); + return scope.status(); +} + +Status AddOneBody(const Scope& scope, const std::vector& inputs, + std::vector* outputs) { + outputs->push_back(ops::AddN(scope, {inputs[0], 1})); + return scope.status(); +} + +Status NestedLoopBody(const Scope& scope, const std::vector& inputs, + std::vector* outputs) { + return ops::BuildWhileLoop(scope.NewSubScope("inner"), inputs, + LessThanTenCond, AddOneBody, "inner_loop", + outputs); +} + +TEST(ValidateControlFlowTest, InputsFromDifferentFrames) { + Scope scope = Scope::NewRootScope().ExitOnError(); + std::vector inputs; + inputs.push_back(ops::Placeholder(scope, DT_INT32)); + std::vector outputs; + TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("outer"), inputs, + LessThanTenCond, NestedLoopBody, + "outer_loop", &outputs)); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + // {inner/Enter', 'outer/Switch'} --> 'inner/Merge'. 'inner/Enter' is in frame + // 'inner_loop'. 'outer/Switch' is in frame 'outer_loop'. + std::vector info; + Status status = BuildAndValidateControlFlowInfo(graph.get(), &info); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "has inputs from different frames")) + << status.error_message(); +} + +TEST(ValidateControlFlowTest, MismatchedParentFrames) { + Scope scope = Scope::NewRootScope().ExitOnError(); + std::vector inputs; + inputs.push_back(ops::Placeholder(scope, DT_INT32)); + std::vector outputs; + TF_ASSERT_OK(ops::BuildWhileLoop(scope, inputs, LessThanTenCond, AddOneBody, + "test_loop", &outputs)); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + Node* enter_1 = nullptr; + for (Node* node : graph->op_nodes()) { + if (IsEnter(node)) { + enter_1 = node; + } + } + ASSERT_TRUE(enter_1 != nullptr); + + NodeDef enter; + enter.set_name("Enter2"); + enter.set_op("Enter"); + (*enter.mutable_attr())["T"].set_type(DT_INT32); + (*enter.mutable_attr())["frame_name"].set_s("test_loop"); + *enter.add_input() = "Enter"; + Status status; + Node* enter_2 = graph->AddNode(enter, &status); + TF_ASSERT_OK(status); + graph->AddControlEdge(enter_1, enter_2); + + // SOURCE("") --> Enter("test_loop") --> Enter2("test_loop") + // For node 'Enter', the parent frame of "test_loop" is empty. + // For node 'Enter2', the parent frame of "test_loop" is "test_loop". + std::vector info; + status = BuildAndValidateControlFlowInfo(graph.get(), &info); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "Mismatched parent frames")) + << status.error_message(); +} + +TEST(ValidateControlFlowTest, TwoLoopCond) { + // Test that one frame has at most one LoopCond node. This is necessary for + // functionalize control flow. + Scope scope = Scope::NewRootScope().ExitOnError(); + std::vector inputs; + inputs.push_back(ops::Placeholder(scope, DT_INT32)); + std::vector outputs; + TF_ASSERT_OK(ops::BuildWhileLoop(scope, inputs, LessThanTenCond, AddOneBody, + "test_loop", &outputs)); + outputs.clear(); + TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("sub"), inputs, + LessThanTenCond, AddOneBody, "test_loop", + &outputs, false)); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + std::vector info; + Status status = BuildAndValidateControlFlowInfo(graph.get(), &info); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "more than one LoopCond node")) + << status.error_message(); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index a8bd199675b4ad3f14960abe1e981442b9432663..9c8e56a17e07348d3cfaaca0b5eb335295af05c3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -652,6 +652,7 @@ Status XlaCompiler::CompileSingleOp( .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); } + FixupSourceAndSinkEdges(graph.get()); return CompileGraph(options, name, std::move(graph), args, result); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index c93850ce270502ea1df1f6469963e96e86994fa2..6be74957c6a92004ca9fdc97747d7b6cb693dd28 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -52,13 +52,7 @@ class XlaContext; // (kind kResource). // // Only kParameter and initialized kResource arguments become runtime parameters -// to the generated XLA computation. The XLA computation will have run-time -// parameters in the following order: -// +---------------------+-----------------------------------------+ -// | kParameter values | Initial values of kResource arguments | -// +---------------------+-----------------------------------------+ -// Within each block, the arguments are arranged by the _Arg index from which -// they were derived. +// to the generated XLA computation. // // The run-time outputs of the XLA computation are arranged in the following // order: @@ -77,10 +71,10 @@ class XlaContext; // tensors with a different shape to their representation inside the XLA // computation. // -// In both inputs and outputs, kResource values are placed the end. When +// In computation outputs, updated kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has -// identical input and output signatures. By moving variable values -// to the end of the argument list and using the +// identical input and output signatures. By passing variable values +// at the end of the argument list and using the // `return_updated_values_for_all_variables` option, we can ensure that the // input and output values of resources appear at the same positions. // @@ -234,7 +228,8 @@ class XlaCompiler { tf2xla::HostComputeMetadata host_compute_metadata; // Resources whose values were updated by the computation, ordered - // by return value position. Resource updates follow the non-constant + // by return value position (which is the same as the order the resources + // were passed as arguments). Resource updates follow the non-constant // results in the outputs of XLA computation. std::vector resource_updates; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 5fbf4b952c6e6f1f50a9ccdabce370874b9fdfd2..613230452b74755ce7543ec2ab82861aa0dfeb7a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -1049,5 +1050,42 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { << status.error_message(); } +TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef no_op; + no_op.set_name("NoOp"); + no_op.set_op("NoOp"); + Status status; + graph->AddNode(no_op, &status); + TF_ASSERT_OK(status); + + std::vector args; + XlaCompiler compiler(DefaultOptions()); + // No control edge linking NoOp with source/sink. + { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: NoOp")) + << status.error_message(); + } + + // Fix control edges for NoOp. + { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result)); + EXPECT_EQ(0, result.resource_updates.size()); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 098072d33cd4eb7f7dec0ec4196b43eca0220d4a..67174b251d3acc381321a0097921fa5c695267fe 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -92,7 +92,7 @@ void XlaContext::AddRetval(int retval_index, DataType type, } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, - const xla::Literal& literal) { + const xla::LiteralSlice& literal) { VLOG(1) << "Adding retval index " << retval_index << " with non-data-dependent tensor to XLA computation"; if (retvals_.size() <= retval_index) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 341bf6ff1f37fa7cd81f41c02a941214067b1bd1..5960daaefd625a0b4daf00d7b8c929f3c856575f 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -83,7 +83,7 @@ class XlaContext : public ResourceBase { // As for Retval, but for return values that are compile-time constants. Status AddConstRetval(int retval_index, DataType dtype, - const xla::Literal& literal); + const xla::LiteralSlice& literal); // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f1594193af09c7193f03b4685d3a7d4510d654dd..93cd340485649b5c55fd6771d24dd9e79989a1f5 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -210,8 +211,9 @@ Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, return errors::InvalidArgument("Invalid argument type ", DataTypeString(dtype)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); + *iota = builder->ConstantLiteral(linspace_literal); return Status::OK(); } @@ -245,8 +247,9 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, return errors::InvalidArgument("Invalid argument type ", DataTypeString(index_type)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 76c68d81af4dd9ec40fe6b1c33b03a876a0c6dc6..c6ddbcc6e1b0dfd558f5deb8412e3b55b3b71ef9 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { @@ -87,6 +88,25 @@ Status XlaOpKernelContext::ConstantInputReshaped( } const XlaExpression* expression = CastExpressionFromTensor(tensor); + auto copy_tensor_to_literal = [](const Tensor& tensor, + xla::Literal* literal) { + xla::Shape literal_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape)); + + *literal = xla::Literal(literal_shape); + + // memcpy over the payload ... + // TODO(phawkins): handle string types. + size_t total_bytes = tensor.TotalBytes(); + if (total_bytes > 0) { + void* dst_ptr = literal->untyped_data(); + const void* src_ptr = DMAHelper::base(&tensor); + memcpy(dst_ptr, src_ptr, total_bytes); + } + return Status::OK(); + }; + // If the tensor has a known constant value, there is no need to invoke XLA. if (expression->has_constant_value()) { Tensor temp(tensor.dtype()); @@ -95,13 +115,15 @@ Status XlaOpKernelContext::ConstantInputReshaped( // with the enclosing Tensor. return errors::Internal("Incompatible shapes in ConstantInputReshaped."); } - return HostTensorToLiteral(temp, constant_literal); + + return copy_tensor_to_literal(temp, constant_literal); } // Make sure we treat zero-element tensors as constant. if (new_shape.num_elements() == 0) { Tensor temp(tensor.dtype(), new_shape); - return HostTensorToLiteral(temp, constant_literal); + + return copy_tensor_to_literal(temp, constant_literal); } xla::XlaOp handle = expression->handle(); @@ -162,7 +184,8 @@ Status XlaOpKernelContext::ConstantInputReshaped( } // Converts an int32 or int64 scalar literal to an int64. -static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { +static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, + int64* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 0) { return errors::InvalidArgument("value is not a scalar"); } @@ -177,7 +200,8 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { } // Converts an float32 or float64 scalar literal to a float64. -static Status LiteralToFloat64Scalar(const xla::Literal& literal, double* out) { +static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, + double* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 0) { return errors::InvalidArgument("value is not a scalar"); } @@ -204,7 +228,7 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { } // Converts an int32 or int64 1D literal to an int64 vector. -static Status LiteralToInt64Vector(const xla::Literal& literal, +static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { if (xla::ShapeUtil::Rank(literal.shape()) != 1) { return errors::InvalidArgument("value is not 1D"); @@ -368,8 +392,9 @@ void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { const TensorShape& shape = constant.shape(); - xla::Literal literal; - OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); + xla::BorrowingLiteral literal; + OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal)); + xla::XlaOp handle = builder()->ConstantLiteral(literal); CHECK_NE(handle.builder(), nullptr); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 4692038b61f6871a8a16299fd4d11e963eb46a57..ee6da6a67a70441bc1cb3a164a623fa389ed03cb 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -71,16 +71,18 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible allow_resource_types settings."; return false; } - if (!x.has_device_whitelist || !y.has_device_whitelist) { - LOG(WARNING) << "Registrations of " << x.name - << " do not both have device whitelists."; + if (!x.has_device_whitelist && !y.has_device_whitelist) { + LOG(WARNING) << "Duplicate registrations of " << x.name + << "with no device whitelists."; return false; } - for (const auto& device : x.device_whitelist) { - if (y.device_whitelist.count(device) != 0) { - LOG(WARNING) << "Multiple registrations of " << x.name << " on device " - << device; - return false; + if (x.has_device_whitelist && y.has_device_whitelist) { + for (const auto& device : x.device_whitelist) { + if (y.device_whitelist.count(device) != 0) { + LOG(WARNING) << "Multiple registrations of " << x.name << " on device " + << device; + return false; + } } } if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) { @@ -157,97 +159,135 @@ void XlaOpRegistry::RegisterCompilationKernels() { registry.jit_kernels_registered_ = true; OpRegistryInterface* op_registry = OpRegistry::Global(); - for (const auto& op : registry.ops_) { - const string& op_name = op.first; - const std::unique_ptr& op_registration = op.second; - const OpDef* op_def; - Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); - if (!lookup_status.ok()) { - LOG(ERROR) << lookup_status.error_message(); - XLA_LOG_LINES( - ERROR, "Ops registered: \n" + - dynamic_cast(op_registry)->DebugString(true)); + // Order of op registration: + // The goal is to allow the co-existence of backend-specific kernels and + // generic kernels. To achieve this, we enforce the following order of + // registrations for one op: + // 1. Process op registration with device whitelists: + // this pass registers backend-specific kernels for this op. + // 2. Process op registration without device whitelists: + // this pass registers the kernels for all the other supported backends. + for (auto& ops : registry.ops_) { + const string& op_name = ops.first; + std::vector>& op_registrations = ops.second; + // Partition the op registration so that the ones with device whitelists + // precede the one without device whitelist. + std::partition(op_registrations.begin(), op_registrations.end(), + [](const std::unique_ptr& op_reg) { + return op_reg->has_device_whitelist; + }); + + // Collect a set of backend registered by ops with device whitelists. + // The op registration without whitelists will register a generic kernel + // for all other backends not in this set. + std::unordered_set whitelisted_backend; + for (auto& op_registration : op_registrations) { + if (op_registration->has_device_whitelist) { + whitelisted_backend.insert(op_registration->device_whitelist.begin(), + op_registration->device_whitelist.end()); + } } - TF_CHECK_OK(lookup_status); - std::unordered_set type_attrs; - for (const OpDef::AttrDef& attr_def : op_def->attr()) { - if (attr_def.type() == "type" || attr_def.type() == "list(type)") { - type_attrs.insert(attr_def.name()); + for (auto& op_registration : op_registrations) { + const OpDef* op_def; + Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); + if (!lookup_status.ok()) { + LOG(ERROR) << lookup_status.error_message(); + XLA_LOG_LINES( + ERROR, + "Ops registered: \n" + + dynamic_cast(op_registry)->DebugString(true)); } - } + TF_CHECK_OK(lookup_status); - // Checks there are no type constraints referring to unknown attributes. - for (const auto& constraint : op_registration->type_constraints) { - if (type_attrs.find(constraint.first) == type_attrs.end()) { - LOG(FATAL) << "Unknown type attribute " << constraint.first - << " in XLA op registration for " << op_name; + std::unordered_set type_attrs; + for (const OpDef::AttrDef& attr_def : op_def->attr()) { + if (attr_def.type() == "type" || attr_def.type() == "list(type)") { + type_attrs.insert(attr_def.name()); + } } - } - for (auto& backend : registry.backends_) { - // If the operator has a device whitelist, only register on whitelisted - // devices. - if (op_registration->has_device_whitelist && - op_registration->device_whitelist.find(backend.first) == - op_registration->device_whitelist.end()) { - continue; + // Checks there are no type constraints referring to unknown attributes. + for (const auto& constraint : op_registration->type_constraints) { + if (type_attrs.find(constraint.first) == type_attrs.end()) { + LOG(FATAL) << "Unknown type attribute " << constraint.first + << " in XLA op registration for " << op_name; + } } - std::unique_ptr kdef(new KernelDef); - kdef->set_op(op_registration->name); - kdef->set_device_type(backend.first); - - // Constrain each type attribute to the intersection of: - // a) the types supported by the backend, and - // b) the types allowed by the OpDef, and - // c) the type constraints. - for (const string& type_attr : type_attrs) { - KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); - attr_constraint->set_name(type_attr); - auto* allowed_values = - attr_constraint->mutable_allowed_values()->mutable_list(); - - const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); - const auto* op_def_allowed_types = - op_def_attr.has_allowed_values() - ? &op_def_attr.allowed_values().list().type() - : nullptr; - auto constraint_it = op_registration->type_constraints.find(type_attr); - const std::set* type_constraints = - constraint_it != op_registration->type_constraints.end() - ? &constraint_it->second - : nullptr; - for (DataType dtype : backend.second.supported_types) { - // Filter out types that aren't allowed by the OpDef. - if (op_def_allowed_types != nullptr && - std::find(op_def_allowed_types->begin(), - op_def_allowed_types->end(), - dtype) == op_def_allowed_types->end()) { - continue; + for (auto& backend : registry.backends_) { + // If the operator has a device whitelist, only register on whitelisted + // devices. + if (op_registration->has_device_whitelist && + op_registration->device_whitelist.find(backend.first) == + op_registration->device_whitelist.end()) { + continue; + } + + // If the operator does NOT has a device whitelist, skip all devices + // that has already been registered. + if (!op_registration->has_device_whitelist && + whitelisted_backend.find(backend.first) != + whitelisted_backend.end()) { + continue; + } + + std::unique_ptr kdef(new KernelDef); + kdef->set_op(op_registration->name); + kdef->set_device_type(backend.first); + + // Constrain each type attribute to the intersection of: + // a) the types supported by the backend, and + // b) the types allowed by the OpDef, and + // c) the type constraints. + for (const string& type_attr : type_attrs) { + KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); + attr_constraint->set_name(type_attr); + auto* allowed_values = + attr_constraint->mutable_allowed_values()->mutable_list(); + + const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); + const auto* op_def_allowed_types = + op_def_attr.has_allowed_values() + ? &op_def_attr.allowed_values().list().type() + : nullptr; + auto constraint_it = + op_registration->type_constraints.find(type_attr); + const std::set* type_constraints = + constraint_it != op_registration->type_constraints.end() + ? &constraint_it->second + : nullptr; + for (DataType dtype : backend.second.supported_types) { + // Filter out types that aren't allowed by the OpDef. + if (op_def_allowed_types != nullptr && + std::find(op_def_allowed_types->begin(), + op_def_allowed_types->end(), + dtype) == op_def_allowed_types->end()) { + continue; + } + // Filter out types based on the type constraints. + if (type_constraints != nullptr && + type_constraints->find(dtype) == type_constraints->end()) { + continue; + } + // Passed all the filters, this type is allowed. + allowed_values->add_type(dtype); } - // Filter out types based on the type constraints. - if (type_constraints != nullptr && - type_constraints->find(dtype) == type_constraints->end()) { - continue; + if (op_registration->allow_resource_types) { + allowed_values->add_type(DT_RESOURCE); } - // Passed all the filters, this type is allowed. - allowed_values->add_type(dtype); } - if (op_registration->allow_resource_types) { - allowed_values->add_type(DT_RESOURCE); + if (backend.second.op_filter != nullptr && + !backend.second.op_filter(kdef.get())) { + continue; } + VLOG(2) << "XLA op registration: device: " << backend.first + << " op: " << op_name; + registry.kernel_registrars_.emplace_back( + new kernel_factory::OpKernelRegistrar( + new KernelDef(*kdef), "XlaJitOp", op_registration->factory)); + backend.second.kernel_defs.push_back(std::move(kdef)); } - if (backend.second.op_filter != nullptr && - !backend.second.op_filter(kdef.get())) { - continue; - } - VLOG(2) << "XLA op registration: device: " << backend.first - << " op: " << op_name; - registry.kernel_registrars_.emplace_back( - new kernel_factory::OpKernelRegistrar( - new KernelDef(*kdef), "XlaJitOp", op_registration->factory)); - backend.second.kernel_defs.push_back(std::move(kdef)); } } } @@ -265,12 +305,12 @@ std::vector XlaOpRegistry::DeviceKernels( << "Unknown backend " << compilation_device_name; for (const std::unique_ptr& k : it->second.kernel_defs) { auto op_iter = registry.ops_.find(k->op()); - CHECK(op_iter != registry.ops_.end()); + CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty()); // The test in IsCompatible ensures that if there are multiple matching // registrations for this op name, they all have the same value of // compilation_only, so only the first match needs to be tested. if (include_compilation_only_kernels || - !op_iter->second->compilation_only) { + !op_iter->second.front()->compilation_only) { kernels.push_back(k.get()); } } @@ -282,10 +322,13 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) { XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); auto it = registry.ops_.find(op); - if (it == registry.ops_.end()) { + if (it == registry.ops_.end() || it->second.empty()) { return nullptr; } - return &it->second->compile_time_constant_inputs; + // The test in IsCompatible ensures that if there are multiple matching + // registrations for this op name, they all have the same value of + // compile_time_constant_inputs, so only the first match is returned. + return &it->second.front()->compile_time_constant_inputs; } std::vector XlaOpRegistry::BackendNames() { @@ -378,16 +421,15 @@ XlaOpRegistrar::XlaOpRegistrar( std::unique_ptr registration) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); mutex_lock lock(registry.mutex_); - auto existing_ops = registry.ops_.equal_range(registration->name); - for (auto existing = existing_ops.first; existing != existing_ops.second; - ++existing) { - if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) { + auto& existing_ops = registry.ops_[registration->name]; + for (auto& existing : existing_ops) { + if (!XlaOpRegistry::IsCompatible(*existing, *registration)) { LOG(FATAL) << "XLA op registration " << registration->name << " is incompatible with existing registration of the same name."; } } - registry.ops_.emplace(registration->name, std::move(registration)); + existing_ops.emplace_back(std::move(registration)); } XlaBackendRegistrar::XlaBackendRegistrar( diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index e255b01dd7fdcb095c7992d4352d2d9bb7d36ac3..2d4593ea4999ad6d8cd0f0e2eec9c6d69c3020b8 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -203,7 +203,7 @@ class XlaOpRegistry { // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP. // Registrations present under the same key must satisfy IsCompatible above, // and this is checked during registration. - std::unordered_multimap> ops_ + std::unordered_map>> ops_ GUARDED_BY(mutex_); // Have we already registered the JIT kernels on the JIT devices? diff --git a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a2ec8dc730e6daf6ffd4c9ea71567c8b23e5e310 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc @@ -0,0 +1,86 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// This test is to verify the correctness of XLA op registration with specific +// backend overrides. + +// A dummy backend-specific OpKernel for CPU. +class DummyCPUOp : public XlaOpKernel { + public: + explicit DummyCPUOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, ctx->Input(0)); + } +}; + +// A dummy generic OpKernel for all backends. +class DummyGenericOp : public XlaOpKernel { + public: + explicit DummyGenericOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + ctx->SetOutput(0, ctx->Input(0)); + } +}; + +REGISTER_OP("DummyDuplicateOp") + .Attr("T: {float, int32}") + .Input("input: int32") + .Output("output: int32") + .Doc(R"doc( +A dummy Op. + +input: dummy input. +output: dummy output. +)doc"); + +// Register the DummyCPUOp kernel for CPU with type INT32. +REGISTER_XLA_OP(Name("DummyDuplicateOp") + .Device(DEVICE_CPU_XLA_JIT) + .TypeConstraint("T", DT_INT32), + DummyCPUOp); +// Register the DummyGeneric kernel for all registered device (except CPU since +// it is already registered), with type FLOAT. +REGISTER_XLA_OP(Name("DummyDuplicateOp").TypeConstraint("T", DT_FLOAT), + DummyGenericOp); + +// Test the correctness of registered kernels. The kernel registered for CPU +// should have type INT32 while all other kernels should have type FLOAT. +TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) { + XlaOpRegistry::RegisterCompilationKernels(); + auto registered_kernels = GetAllRegisteredKernels(); + for (const auto& kernels : registered_kernels) { + if (kernels.op() == "DummyDuplicateOp") { + EXPECT_EQ(kernels.constraint_size(), 1); + EXPECT_EQ(kernels.constraint(0).name(), "T"); + if (kernels.device_type() == "XLA_CPU_JIT") { + EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0), + DT_INT32); + } else { + EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0), + DT_FLOAT); + } + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index c6deb959a59f7b79500a0948b4035ea56cd9b4a1..4525197146b7f29f405650bdb08e5946cbce8114 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -53,7 +53,6 @@ xla_proto_library( deps = [ ":xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:session_proto", ], ) @@ -310,7 +309,6 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index c4f0c4468fccee87818974d0cefe26983179dcf5..8f08d3b2e04670ad6590aca1db0fd9d25faed83f 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -110,6 +110,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index dc69d2097ebe14ca0e14a39849d4fcae99024fdc..5c9abad4c3126be5e45e96c770c0679fe8606788 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -24,7 +24,8 @@ namespace xla { StatusOr>> CompileOnlyClient::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { + const AotCompilationOptions& options, + std::unique_ptr* metadata) { std::vector service_instances; service_instances.reserve(computations.size()); for (const AotXlaComputationInstance& instance : computations) { @@ -36,7 +37,8 @@ CompileOnlyClient::CompileAheadOfTime( service_instance.argument_layouts = instance.argument_layouts; service_instance.result_layout = instance.result_layout; } - return compiler_service_->CompileAheadOfTime(service_instances, options); + return compiler_service_->CompileAheadOfTime(service_instances, options, + metadata); } int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index f9a7c31270c7a11175f47a537639a97d0c9211af..332c96503637344d56e363e19db4880c37ca9684 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -46,13 +46,15 @@ class CompileOnlyClient : public Client { const Shape* result_layout; }; - // Compiles a list of xla computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. + // Compiles a list of xla computations for ahead-of-time execution. + // This is intended for use in static compilation. The |options| + // parameter describes the target for which the compiler should emit + // code. |metadata|, if provided, is populated during compilation. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options); + const AotCompilationOptions& options, + std::unique_ptr* metadata = nullptr); // Returns the size of a pointer in bytes for a given triple. static int64 PointerSizeForTriple(tensorflow::StringPiece triple); diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index a1d34796ccfd86f2025eff0ecb51338eb6a9b1da..f095ec92131c05355144c01539bd19bad0e77baf 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -121,4 +121,133 @@ StatusOr Any(const XlaOp& predicates, XlaBuilder* builder) { return builder->Reduce(predicates, f, logical_or, all_dimensions); } +namespace { +xla::XlaOp FloatLiteral(xla::XlaBuilder* b, PrimitiveType data_type, + float value) { + return b->ConvertElementType(b->ConstantR0(value), data_type); +} + +// Polynomials for computing erf/erfc. Originally from cephes. +// Note we use float for compatibility across devices, at the cost of some +// precision for 64 bit computations. +// +// Coefficients are in descending order. +std::array kErfcPCoefficient = { + 2.46196981473530512524E-10, 5.64189564831068821977E-1, + 7.46321056442269912687E0, 4.86371970985681366614E1, + 1.96520832956077098242E2, 5.26445194995477358631E2, + 9.34528527171957607540E2, 1.02755188689515710272E3, + 5.57535335369399327526E2}; +std::array kErfcQCoefficient = { + 1.00000000000000000000E0, 1.32281951154744992508E1, + 8.67072140885989742329E1, 3.54937778887819891062E2, + 9.75708501743205489753E2, 1.82390916687909736289E3, + 2.24633760818710981792E3, 1.65666309194161350182E3, + 5.57535340817727675546E2}; +std::array kErfcRCoefficient = { + 5.64189583547755073984E-1, 1.27536670759978104416E0, + 5.01905042251180477414E0, 6.16021097993053585195E0, + 7.40974269950448939160E0, 2.97886665372100240670E0}; +std::array kErfcSCoefficient = { + 1.00000000000000000000E0, 2.26052863220117276590E0, + 9.39603524938001434673E0, 1.20489539808096656605E1, + 1.70814450747565897222E1, 9.60896809063285878198E0, + 3.36907645100081516050E0}; +std::array kErfTCoefficient = { + 9.60497373987051638749E0, 9.00260197203842689217E1, + 2.23200534594684319226E3, 7.00332514112805075473E3, + 5.55923013010394962768E4}; +std::array kErfUCoefficient = { + 1.00000000000000000000E0, 3.35617141647503099647E1, + 5.21357949780152679795E2, 4.59432382970980127987E3, + 2.26290000613890934246E4, 4.92673942608635921086E4}; +} // namespace + +// Evaluate the polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x, + tensorflow::gtl::ArraySlice coefficients, + PrimitiveType data_type) { + xla::XlaOp poly = FloatLiteral(b, data_type, 0.0); + for (float c : coefficients) { + poly = b->Add(b->Mul(poly, x), FloatLiteral(b, data_type, c)); + } + return poly; +} + +// Compute an approximation of the error function complement (1 - erf(x)). +xla::XlaOp Erfc(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type) { + xla::XlaOp zero = FloatLiteral(b, data_type, 0.0); + xla::XlaOp two = FloatLiteral(b, data_type, 2.0); + xla::XlaOp eight = FloatLiteral(b, data_type, 8.0); + + xla::XlaOp abs_x = b->Abs(x); + xla::XlaOp z = b->Exp(b->Mul(b->Neg(x), x)); + + xla::XlaOp pp = EvaluatePolynomial(b, abs_x, kErfcPCoefficient, data_type); + xla::XlaOp pq = EvaluatePolynomial(b, abs_x, kErfcQCoefficient, data_type); + xla::XlaOp pr = EvaluatePolynomial(b, abs_x, kErfcRCoefficient, data_type); + xla::XlaOp ps = EvaluatePolynomial(b, abs_x, kErfcSCoefficient, data_type); + + xla::XlaOp y = b->Select(b->Lt(abs_x, eight), b->Div(b->Mul(z, pp), pq), + b->Div(b->Mul(z, pr), ps)); + + return b->Select(b->Lt(x, zero), b->Sub(two, y), y); +} + +// Compute a polynomial approximation of the error function. +xla::XlaOp Erf(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type) { + xla::XlaOp z = b->Mul(x, x); + xla::XlaOp pt = EvaluatePolynomial(b, z, kErfTCoefficient, data_type); + xla::XlaOp pu = EvaluatePolynomial(b, z, kErfUCoefficient, data_type); + return b->Div(b->Mul(x, pt), pu); +} + +// Approximation for the inverse error function from +// Giles, M., "Approximating the erfinv function". +// The approximation has the form: +// w = -log((1 - x) * (1 + x)) +// if ( w < 5 ) { +// w = w - 2.5 +// p = sum_{i=1}^n lq[i]*w^i +// } else { +// w = sqrt(w) - 3 +// p = sum_{i=1}^n gq[i]*w^i +// } +// return p*x +StatusOr ErfInv(xla::XlaBuilder* b, const xla::XlaOp& x) { + TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x)); + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = b->ConstantR0(1.0); + auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); + + auto lt = b->Lt(w, b->ConstantR0(5.0)); + auto coefficient = [&](int i) { + return b->Select( + lt, + b->Broadcast(b->ConstantR0(w_less_than_5_constants[i]), + AsInt64Slice(shape.dimensions())), + b->Broadcast(b->ConstantR0(w_greater_than_5_constants[i]), + AsInt64Slice(shape.dimensions()))); + }; + w = b->Select(lt, b->Sub(w, b->ConstantR0(2.5f)), + b->Sub(b->SqrtF32(w), b->ConstantR0(3.0f))); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = b->Add(coefficient(i), b->Mul(p, w)); + } + return b->Mul(p, x); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index 64b6b7d63353165e45bf12d35126a7eeef9e56e4..efdcc7e198c4ae73a69802ed8f04c6c048d902dc 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -55,6 +55,23 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder); // Note: if predicates is zero-sized, Any() vacuously returns false. StatusOr Any(const XlaOp& predicates, XlaBuilder* builder); +// Evaluate the polynomial given coefficients and `x`. +// N.B. Coefficients should be supplied in decreasing order. +xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x, + tensorflow::gtl::ArraySlice coefficients, + PrimitiveType data_type); + +// Compute an approximation of the error function complement (1 - erf(x)). +xla::XlaOp Erfc(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type); + +// Compute an approximation of the error function. +xla::XlaOp Erf(xla::XlaBuilder* b, const xla::XlaOp& x, + PrimitiveType data_type); + +// Compute an approximation of the inverse of the error function. +StatusOr ErfInv(xla::XlaBuilder* b, const xla::XlaOp& x); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index f9003373a6e809a57855e249cfc255b913fb8bc0..5f9710914bd0ceff55f5b0a2db05e553ce8bd637 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -51,24 +51,17 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend) { - const ComputationLayout& host_computation_layout = - executable_->module_config().host_entry_computation_layout(); - const ComputationLayout& device_computation_layout = - executable_->module_config().device_entry_computation_layout(); + const ComputationLayout& computation_layout = + executable_->module_config().entry_computation_layout(); // Check argument number, shapes, and layouts. - if (arguments.size() != host_computation_layout.parameter_count()) { + if (arguments.size() != computation_layout.parameter_count()) { return InvalidArgument( "invalid number of arguments for computation: expected %d, got %zu", - host_computation_layout.parameter_count(), arguments.size()); - } - if (arguments.size() != device_computation_layout.parameter_count()) { - return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", - device_computation_layout.parameter_count(), arguments.size()); + computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { - if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape( + if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( arguments[i]->on_host_shape())) { return InvalidParameterArgument( executable_.get(), i, @@ -76,24 +69,10 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString( - host_computation_layout.parameter_layout(i).shape()) + ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) .c_str(), ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } - if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape( - arguments[i]->on_device_shape())) { - return InvalidParameterArgument( - executable_.get(), i, - "Argument does not match device shape or layout of computation " - "parameter " - "%d: want %s, got %s", - i, - ShapeUtil::HumanString( - device_computation_layout.parameter_layout(i).shape()) - .c_str(), - ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str()); - } } if (run_options.stream() != nullptr) { @@ -185,7 +164,7 @@ StatusOr LocalExecutable::Run( run_options, backend_->StreamBorrower(), backend_->eigen_intra_op_thread_pool()); - if (executable_->dumping()) { + if (executable_->dumping_snapshot()) { return ExecuteAndDump(&service_options, arguments); } return executable_->ExecuteOnStreamWrapper( @@ -195,45 +174,44 @@ StatusOr LocalExecutable::Run( StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments) { - executable_->session_module()->set_execution_platform( + executable_->hlo_snapshot()->set_execution_platform( backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module())); + TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer result, executable_->ExecuteOnStream(run_options, arguments, /*hlo_execution_profile=*/nullptr)); - TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module())); - TF_RETURN_IF_ERROR(executable_->DumpSessionModule()); + TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot())); + TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot()); return std::move(result); } Status LocalExecutable::RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module) { - session_module->clear_arguments(); + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*argument)); - *session_module->add_arguments() = literal->ToProto(); + *hlo_snapshot->add_arguments() = literal->ToProto(); } return Status::OK(); } Status LocalExecutable::RecordResult(const ShapedBuffer* result, - SessionModule* session_module) { - session_module->clear_result(); + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_result(); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*result)); - *session_module->mutable_result() = literal->ToProto(); + *hlo_snapshot->mutable_result() = literal->ToProto(); return Status::OK(); } StatusOr> LocalExecutable::LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend_->stream_executor(shaped_buffer.device_ordinal())); - return backend_->transfer_manager()->TransferLiteralFromDevice(executor, + TF_ASSIGN_OR_RETURN(auto stream, + backend_->BorrowStream(shaped_buffer.device_ordinal())); + return backend_->transfer_manager()->TransferLiteralFromDevice(stream.get(), shaped_buffer); } @@ -288,19 +266,18 @@ StatusOr LocalClient::LiteralToShapedBuffer( TF_ASSIGN_OR_RETURN(auto scoped_buffer, backend().transfer_manager()->AllocateScopedShapedBuffer( literal.shape(), allocator, device_ordinal)); - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - backend().stream_executor(device_ordinal)); + TF_ASSIGN_OR_RETURN(auto stream, + mutable_backend()->BorrowStream(device_ordinal)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, literal, scoped_buffer)); + stream.get(), literal, scoped_buffer)); return std::move(scoped_buffer); } StatusOr> LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - backend().stream_executor(shaped_buffer.device_ordinal())); - return backend().transfer_manager()->TransferLiteralFromDevice(executor, + TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream( + shaped_buffer.device_ordinal())); + return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(), shaped_buffer); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 5b408cc6b246282e362c7ce2eade369e3d18044d..4d9e0d7cd9d6ddebead1e12b23e94b529038039b 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -78,11 +79,10 @@ class LocalExecutable { // proto. Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module); + HloSnapshot* hlo_snapshot); // Records the result of the computation in a SessionModule proto. - Status RecordResult(const ShapedBuffer* result, - SessionModule* session_module); + Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. StatusOr> LiteralFromShapedBuffer( diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index ae506317c2e4862d77cb4f0628e919871ad1aeb2..d7ebcf8bebc1f656b4965c833e0d42ccceb1b99f 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -1611,14 +1611,40 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, }); } -XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids) { return NoteErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); + auto b = CreateSubBuilder("sum"); + b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), + b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); + TF_ASSIGN_OR_RETURN(auto computation, b->Build()); + return CrossReplicaSum(operand, computation, replica_group_ids, + /*channel_id=*/tensorflow::gtl::nullopt); + }); +} +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + return NoteErrorOrReturn([&]() -> StatusOr { + if (channel_id.has_value()) { + return Unimplemented("channel_id is not supported in AllReduce"); + } + + HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + for (int64 replica_group_id : replica_group_ids) { + instr.add_replica_group_ids(replica_group_id); + } + + AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, {operand}); diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 2b3013a91c488782098bd81994e899eae5a1f506..0329e42ed1aef8edd1537e888ddcd78f08584407 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -528,9 +528,35 @@ class XlaBuilder { tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding); - // Returns the sum of the operand value across all replicas. All replicas - // supply one input to the sum and all replicas receive the resulting sum. - XlaOp CrossReplicaSum(const XlaOp& operand); + // Returns the sum of the operand value within each subgroup of replicas. All + // replicas supply one input to the sum and all replicas receive the resulting + // sum for each subgroup. + XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids = {}); + + // Enqueues an operation that do an AllReduce of the operand cross cores. Here + // AllReduce means doing a reduction on the input operand cross cores and then + // broadcasting the reduction result to those cores. The reduction function is + // defined by `computation`, which should be a commutative computation on + // scalars, e.g., add, min, or max. The way that AllReduce is applied is + // configured by: + // + // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // - `channel_id`: for Allreduce nodes from different models, if they have the + // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross models. + // + // TODO(b/79737069): Rename this to AllReduce when it's ready to use. + XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); // Enqueues an operation that scatters the `source` array to the selected // indices of each window. diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/BUILD b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a26b20c861846501c911253d89619591c37322b3 --- /dev/null +++ b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD @@ -0,0 +1,18 @@ +# Description: +# Python API for shardings in XLA. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_library( + name = "xla_sharding", + srcs = ["xla_sharding.py"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/compiler/xla/python_api:types", + "//tensorflow/compiler/xla/python_api:xla_shape", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..abd10b164eaef8e75ed304483861baf250c5b954 --- /dev/null +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -0,0 +1,204 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""Experimental support for defining XLA shardings.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import xla_shape +from tensorflow.core.framework import attr_value_pb2 + + +class Sharding(object): + """A class to support adding sharding attributes to Ops. + + Use the factory constructors and then call apply_to_tensor: + Sharding.replicate().apply_to_tensor(tensor) + """ + + def __init__(self, proto=None): + """Do not use this constructor; use the factory functions below.""" + self._proto = proto + + @classmethod + def replicate(cls): + """Returns a replicated sharding attribute. + + This causes an op to be computed in its entirety independently on all + cores in the XLA device. + """ + return Sharding( + proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)) + + @classmethod + def assign_device(cls, core): + """Returns an AssignDevice sharding attribute. + + This causes an op to be computed in its entirety only on one core in + the XLA device. + Args: + core: The core to assign this Op to. + """ + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.MAXIMAL, + tile_assignment_dimensions=[1], + tile_assignment_devices=[core])) + + @classmethod + def tile(cls, tile_shape, tile_assignment): + """Returns a Tiled sharding attribute. + + This causes an op to be partially computed on multiple cores in the + XLA device. + + Args: + tile_shape: A xla_shape.Shape describing the tile shape that each core + will compute. + The tile shape does not need to be divisible by the tile assignment. + tile_assignment: An np.ndarray describing the topology of the tiling and + which device will compute which part of the topology. + + Raises: + TypeError: tile_assignment was not of np.array type or tile_shape was + not of xla_shape.Shape type. + + TODO(jmolloy): This concept is nefarious and is not + something we really want to expose to users (especially as the + contract for tile_assignment is very strict). + """ + if not isinstance(tile_assignment, np.ndarray): + raise TypeError('Tile assignment must be of type np.ndarray') + if not isinstance(tile_shape, xla_shape.Shape): + raise TypeError('Tile shape must be of type xla_shape.Shape') + dims = list(tile_assignment.shape) + flattened_devices = tile_assignment.reshape(-1, order='C') + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.OTHER, + tile_shape=tile_shape.message, + tile_assignment_dimensions=dims, + tile_assignment_devices=list(flattened_devices))) + + @classmethod + def split(cls, tensor, split_dimension, num_devices): + """Returns a Sharding that splits a tensor across a dimension. + + This creates a Tiled attribute, similar to tile(), but easier to use for the + common case of tiling a tensor N ways in one dimension. + + Args: + tensor: A tf.Tensor to split. + split_dimension: The dimension number to split. + num_devices: The number of cores to split `tensor` over. + + Raises: + ValueError: The tensor to split was smaller in the split dimension than + the number of devices to split over. + """ + tensor.shape.assert_is_fully_defined() + shape = tensor.shape.as_list() + if shape[split_dimension] < num_devices: + raise ValueError('Split dimension was smaller than the required number ' + 'of splits: shape=%r, dimension=%r, num_devices=%r', + shape, split_dimension, num_devices) + + tile_shape = shape + tile_shape[split_dimension] = int( + math.ceil(tile_shape[split_dimension] / num_devices)) + tile_shape_proto = xla_data_pb2.Shape( + element_type=xla_data_pb2.F32, dimensions=tile_shape) + + tile_assignment_dims = [1] * len(shape) + tile_assignment_dims[split_dimension] = num_devices + + return Sharding( + proto=xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.OTHER, + tile_shape=tile_shape_proto, + tile_assignment_dimensions=tile_assignment_dims, + tile_assignment_devices=range(num_devices))) + + def apply_to_tensor(self, tensor): + """Applies this Sharding attribute to `tensor`.""" + if len(tensor.op.outputs) > 1: + proto = self._get_or_create_tuple_proto(tensor.op) + # We can't mutate an element of old_proto.tuple_shardings, so create + # a new proto. + tuple_shardings = list(proto.tuple_shardings) + tuple_shardings[tensor.value_index] = self._proto + proto = xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings) + else: + proto = self._proto + + attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) + # TODO(jmolloy): This need to be seriously revisited before declaring this + # API available for public use. + # pylint: disable=protected-access + tensor.op._set_attr('_XlaSharding', attr_value) + + @property + def proto(self): + """Return the sharding protobuf of type xla_data_pb2.OpSharding.""" + return self._proto + + def _get_or_create_tuple_proto(self, op): + try: + attr = op.get_attr('_XlaSharding') + proto = xla_data_pb2.OpSharding() + proto.ParseFromString(attr) + return proto + except ValueError: + return self._create_tuple_proto(op) + + def _create_tuple_proto(self, op): + shardings = [ + xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED) + for _ in op.outputs + ] + return xla_data_pb2.OpSharding( + type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings) + + +# Helpers for the above factory functions that allow easy application of +# shardings, for example: +# tensor = xla_sharding.replicate(tensor) + + +def replicate(tensor): + Sharding.replicate().apply_to_tensor(tensor) + return tensor + + +def assign_device(tensor, device): + Sharding.assign_device(device).apply_to_tensor(tensor) + return tensor + + +def tile(tensor, tile_shape, tile_assignment): + Sharding.tile(tile_shape, tile_assignment).apply_to_tensor(tensor) + return tensor + + +def split(tensor, split_dimension, num_devices): + Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor) + return tensor diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index e8f29b83291a7cb238dc25b9f4bb743fe426a162..3f059cac30b5d36ab1d097bf200547533822e3d0 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -190,9 +190,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } if (!ShapeUtil::IsArray(shape)) { - return InvalidArgument( - "shape of primitive type %s should not have a layout", - PrimitiveType_Name(shape.element_type()).c_str()); + if (layout.minor_to_major_size() != 0 || + layout.padded_dimensions_size() != 0) { + return InvalidArgument( + "shape of primitive type %s should not have a non-trivial layout", + PrimitiveType_Name(shape.element_type()).c_str()); + } + return Status::OK(); } if (layout.format() == INVALID_FORMAT) { diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index bf9679cafec72c2e9dc5796e9058c6703239c508..2125ab7c61ab5e30fe51e16994e0da4883d509c4 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -606,8 +606,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, } // namespace Status EqualShapes(const Shape& expected, const Shape& actual) { - if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { - return InvalidArgument("tupleness-mismatch! want: %s got %s", + if (expected.element_type() != actual.element_type()) { + return InvalidArgument("element type mismatch, want: %s got %s", ShapeUtil::HumanString(expected).c_str(), ShapeUtil::HumanString(actual).c_str()); } @@ -626,7 +626,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return AppendStatus(result, StrCat("mismatch in tuple index", i)); } } - } else { + } else if (ShapeUtil::IsArray(expected)) { if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { return InvalidArgument("want rank of %s got rank of %s", ShapeUtil::HumanString(expected).c_str(), @@ -652,6 +652,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { } } } + // Non-array, non-tuple shapes are trivially equivalent. return Status::OK(); } @@ -705,6 +706,9 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } break; } + case TOKEN: + // Tokens have no on-device representation and are trivially equal. + return Status::OK(); default: LOG(FATAL) << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 61afc311a702930a18be4842908f9a26b98d9a32..7c6a181b0a872bb03f1153017f16d1d06a99ecaa 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -148,8 +148,7 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->emplace_back(std::move(child_piece)); } - } else { - CHECK(ShapeUtil::IsArray(shape)); + } else if (ShapeUtil::IsArray(shape)) { if (allocate_arrays) { if (LayoutUtil::IsSparseArray(shape)) { // For sparse arrays, the buffer must be of the size of the maximum @@ -165,6 +164,10 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { piece->set_buffer(new char[piece->size_bytes()]); } } + } else { + // If the shape is neither an array nor tuple, then it must be + // zero-sized. Otherwise, some memory needs to be allocated for it. + CHECK_EQ(piece->size_bytes(), 0); } } @@ -264,8 +267,8 @@ Status Literal::CopySliceFromInternal( StridedCopy(data(), linear_index(shape(), dest_base), 0, src_literal.data(), linear_index(src_literal.shape(), src_base), 0, 1); - } else if (!ShapeUtil::HasZeroElements(shape()) && - !ShapeUtil::HasZeroElements(src_literal.shape())) { + } else if (!ShapeUtil::IsZeroElementArray(shape()) && + !ShapeUtil::IsZeroElementArray(src_literal.shape())) { // Perform copy if neither src nor dest has dimensions with zero element, // otherwise it's a no-op. TF_RET_CHECK(src_base.size() == dest_base.size()); @@ -327,6 +330,10 @@ Status Literal::CopyElementFrom(const LiteralSlice& src_literal, return Status::OK(); } +/* static */ std::unique_ptr Literal::CreateToken() { + return MakeUnique(ShapeUtil::MakeTokenShape()); +} + std::vector Literal::DecomposeTuple() { CHECK(ShapeUtil::IsTuple(shape())); std::vector elements; @@ -379,7 +386,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, tensorflow::gtl::ArraySlice src, const Shape& dest_shape, const Shape& src_shape) { CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); - if (ShapeUtil::HasZeroElements(dest_shape)) { + if (ShapeUtil::IsZeroElementArray(dest_shape)) { return; } std::vector index(ShapeUtil::Rank(dest_shape)); @@ -1177,7 +1184,7 @@ size_t LiteralBase::Hash() const { ShapeUtil::ForEachSubshape( shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsTuple(subshape)) { + if (!ShapeUtil::IsArray(subshape)) { return; } @@ -1368,6 +1375,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, return; } + if (ShapeUtil::IsToken(subshape)) { + pieces->push_back("token"); + return; + } + if (LayoutUtil::IsSparseArray(subshape)) { pieces->push_back(shape_to_string(subshape)); pieces->push_back("{"); @@ -1556,7 +1568,7 @@ string LiteralBase::ToString(bool print_layout) const { void LiteralBase::EachCellAsString( const std::function indices, const string& value)>& per_cell) const { - if (ShapeUtil::HasZeroElements(shape())) { + if (ShapeUtil::IsZeroElementArray(shape())) { return; } std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( @@ -1962,7 +1974,7 @@ bool LiteralBase::IsAllFirst() const { // Empty shapes are not all the first element since there is no first // element. - if (ShapeUtil::HasZeroElements(piece.subshape())) { + if (ShapeUtil::IsZeroElementArray(piece.subshape())) { return false; } auto piece_is_all = [&]() { @@ -2341,28 +2353,27 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal, : LiteralBase(), root_piece_(&literal.piece(view_root)) {} BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(shape) { - CHECK(ShapeUtil::IsArray(shape_)); - CHECK_NE(src_buf_ptr, nullptr); - CHECK(LayoutUtil::HasLayout(shape_)); + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsArray(*shape_)); + CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = Piece(); root_piece_.set_buffer(const_cast(src_buf_ptr)); - root_piece_.set_subshape(&shape_); + root_piece_.set_subshape(shape_.get()); } BorrowingLiteral::BorrowingLiteral( tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(shape) { - CHECK(ShapeUtil::IsTuple(shape_)); - CHECK(!ShapeUtil::IsNestedTuple(shape_)); - CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_)); + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); root_piece_ = Piece(); - root_piece_.set_subshape(&shape_); - BuildPieceSubtree(shape_, &root_piece_); + root_piece_.set_subshape(shape_.get()); + BuildPieceSubtree(*shape_, &root_piece_); for (int i = 0; i < src_buf_ptrs.size(); ++i) { - const auto& src_shape = shape_.tuple_shapes(i); + const auto& src_shape = shape_->tuple_shapes(i); CHECK(ShapeUtil::IsArray(src_shape)); root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 1e26eb7ad4098bab1e757347a23edd73390b48b5..37ca8ea9f1d158b6bce8d5688288351f55c3b3c8 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -917,6 +917,9 @@ class Literal : public LiteralBase { return MakeTupleOwned(std::move(v)); } + // Create a constant token literal. Token types have no value. + static std::unique_ptr CreateToken(); + // Returns a vector containing the tuple elements of this Literal as separate // Literals. This Literal must be tuple-shaped and can be a nested tuple. The // elements are moved into the new Literals; no data is copied. Upon return @@ -1099,8 +1102,10 @@ class BorrowingLiteral : public LiteralBase { const Piece& root_piece() const override { return root_piece_; }; Piece root_piece_; - // Shape of this literal. - const Shape shape_; + // Shape of this literal. Stored as unique_ptr so such that the (default) + // move construction of this class would be trivially correct: the pointer to + // Shape root_piece_ stores will still point to the correct address. + std::unique_ptr shape_; }; template @@ -1454,7 +1459,7 @@ void LiteralBase::EachCell( std::function indices, NativeT value)> per_cell) const { - if (ShapeUtil::HasZeroElements(shape())) { + if (ShapeUtil::IsZeroElementArray(shape())) { return; } std::vector indices(ShapeUtil::Rank(shape()), 0); diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index f127cee0fdc126429ed423aace3b3b7764a05b2e..493d807591dd3c425293e4ee796bca3036a3088c 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -334,6 +334,22 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { EXPECT_EQ(nil, nil); } +TEST_F(LiteralUtilTest, TokenEquality) { + auto token0 = Literal::CreateToken(); + auto token1 = Literal::CreateToken(); + auto scalar = Literal::CreateR0(1.0); + + EXPECT_EQ(*token0, *token1); + EXPECT_NE(*token0, *scalar); + + EXPECT_EQ(*Literal::MakeTuple({token0.get()}), + *Literal::MakeTuple({token0.get()})); + EXPECT_EQ(*Literal::MakeTuple({token0.get(), scalar.get()}), + *Literal::MakeTuple({token1.get(), scalar.get()})); + EXPECT_NE(*Literal::MakeTuple({token0.get(), scalar.get()}), + *Literal::MakeTuple({scalar.get(), token1.get()})); +} + TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. auto colmajor = @@ -1431,7 +1447,7 @@ TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } -TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { +TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { std::vector int64_values = {1, 2, 3}; const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); @@ -1443,7 +1459,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { EXPECT_EQ(literal.Get({2}), 3); } -TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) { +TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { std::vector one_two_three = {1, 2, 3}; const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3}); diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 143c9a2366be5786b7ef2148580caeb97d67d2d8..b16147e3be71771269d8b7a18528bef3a8c72d99 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -85,5 +85,10 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) { } } +bool IsArrayType(PrimitiveType primitive_type) { + return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && + primitive_type != OPAQUE && primitive_type != TOKEN; +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index b26a10ade63a5dad3bf8f9f3a2a33c3c5e67bdb2..889e9a1ceca675689406d255d348c82c398563aa 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -133,6 +133,9 @@ bool IsUnsignedIntegralType(PrimitiveType type); bool IsIntegralType(PrimitiveType type); +// Returns true if values of the given primitive type are held in array shapes. +bool IsArrayType(PrimitiveType primitive_type); + // Returns the number of bits in the representation for a given type. int BitWidth(PrimitiveType type); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index f808990cadeab5fd2c4857920ee1daaac7262edd..29062348b0afd0f17bc24cef71f6d3929b131212 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" namespace xla { - namespace swig { // TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of @@ -97,6 +96,36 @@ const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { return &shaped_buffer_; } +ShapedBuffer LocalShapedBuffer::Release() { return shaped_buffer_.release(); } + +LocalShapedBufferTuple::LocalShapedBufferTuple( + std::vector elements) + : elements_(std::move(elements)) { + for (auto* element : elements_) { + DCHECK(element != nullptr); + } +} + +LocalShapedBufferTuple::~LocalShapedBufferTuple() { + for (LocalShapedBuffer* element : elements_) { + if (element != nullptr) { + delete element; + } + } +} + +StatusOr LocalShapedBufferTuple::Release(int i) { + LocalShapedBuffer* element = elements_[i]; + if (element == nullptr) { + return InvalidArgument("Attempted to release already-released element %d.", + i); + } + elements_[i] = nullptr; + return element; +} + +int LocalShapedBufferTuple::size() const { return elements_.size(); } + static StatusOr ToBuffer(LocalClient* client, int device_ordinal, const Literal& arg) { @@ -315,11 +344,8 @@ LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, return builder_.Parameter(parameter_number, shape, name); } -std::unique_ptr LocalComputationBuilder::GetShape( - const LocalOp& operand) { - auto result = MakeUnique(); - *result = builder_.GetShape(operand.op()).ValueOrDie(); - return result; +StatusOr LocalComputationBuilder::GetShape(const LocalOp& operand) { + return builder_.GetShape(operand.op()); } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -598,10 +624,12 @@ _FORWARD_BINOP(Or) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) +_FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) +_FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) @@ -631,6 +659,54 @@ void DeleteLocalComputation(LocalComputation* computation) { delete computation; } -} // namespace swig +StatusOr DestructureLocalShapedBufferTuple( + LocalShapedBuffer* local_shaped_buffer) { + if (!ShapeUtil::IsTuple( + local_shaped_buffer->shaped_buffer()->on_device_shape())) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString( + local_shaped_buffer->shaped_buffer()->on_device_shape()) + .c_str()); + } + DeviceMemoryAllocator* allocator = + local_shaped_buffer->shaped_buffer()->memory_allocator(); + ShapedBuffer tuple_buffer = local_shaped_buffer->Release(); + + // Extract some metadata we use to construct scoped buffers. + const se::Platform* platform = tuple_buffer.platform(); + int device_ordinal = tuple_buffer.device_ordinal(); + + ShapeTree& shape_tree = tuple_buffer.buffers(); + const Shape& tuple_shape = tuple_buffer.on_device_shape(); + std::vector results; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + // Create a shaped buffer for this destructured tuple element. + const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); + VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; + ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); + + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& index) { + ShapeIndex original(index); + original.push_front(i); + se::DeviceMemoryBase* device_memory = + shape_tree.mutable_element(original); + shaped_buffer.set_buffer(*device_memory, index); + *device_memory = se::DeviceMemoryBase(); + }); + + VLOG(3) << "Completed tuple element: " << i; + results.push_back(new LocalShapedBuffer( + ScopedShapedBuffer(std::move(shaped_buffer), allocator))); + } + // Deallocate the root buffer. + se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); + TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); + return new LocalShapedBufferTuple(std::move(results)); +} + +} // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 9ac13b65231c932f152c1e79eb8e576cc6331fbd..95f0a0610b573479e0103ba2d1514844df35c2b4 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { - namespace swig { // Initializes the number of replicas that XLA will be initialized with (when @@ -69,10 +68,42 @@ class LocalShapedBuffer { StatusOr > ToLiteral() const; + // Transfers ownership of the encapsulated ShapedBuffer to the caller, + // analogous to std::unique_ptr::release(). + ShapedBuffer Release(); + private: ScopedShapedBuffer shaped_buffer_; }; +// Result of a tuple destructuring operation on a LocalShapedBuffer -- this +// appears to be a simpler mechanism for the time being than an alternative like +// using SWIG to transform std::vectors into Python lists of SWIG objects +// directly. +class LocalShapedBufferTuple { + public: + // Note: any LocalShapedBuffer elements that are not Release()'d will be + // deallocated in the destructor. + explicit LocalShapedBufferTuple(std::vector elements); + + ~LocalShapedBufferTuple(); + + // Releases the ith element to the caller. Further attempts to release the ith + // element will return an invalid argument error. + StatusOr Release(int i); + + // Returns the number of elements in the destructured tuple. + int size() const; + + private: + std::vector elements_; +}; + +// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements +// in LocalShapedBufferTuple form. +StatusOr DestructureLocalShapedBufferTuple( + LocalShapedBuffer* local_shaped_buffer); + // Wraps a LocalExecutable produced by compiling a // LocalComputation. The Execute method forwards to that of the // underlying LocalExecutable, and additionally handles tranferring @@ -156,7 +187,7 @@ class LocalComputationBuilder { LocalOp Parameter(int64 parameter_number, const Shape& shape, const string& name); - std::unique_ptr GetShape(const LocalOp& operand); + StatusOr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. StatusOr GetReturnValueShape(); @@ -305,10 +336,12 @@ class LocalComputationBuilder { _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) + _FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) + _FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) @@ -336,7 +369,6 @@ void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); void DeleteLocalComputation(LocalComputation* computation); } // namespace swig - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 536b93c6f9381ae5c84e65eb7ed264b5eb158a72..477df6fde25d0db760e08df9d335bd12e31ccb55 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -200,6 +200,20 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalShapedBufferTuple*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + %typemap(out) StatusOr< std::unique_ptr > { if ($1.ok()) { std::unique_ptr value = $1.ConsumeValueOrDie(); @@ -905,6 +919,9 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; +%unignore xla::swig::LocalShapedBufferTuple; +%unignore xla::swig::LocalShapedBufferTuple::Release; +%unignore xla::swig::LocalShapedBufferTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; %unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers; @@ -974,10 +991,12 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; +%unignore xla::swig::LocalComputationBuilder::Expm1; %unignore xla::swig::LocalComputationBuilder::Floor; %unignore xla::swig::LocalComputationBuilder::Ceil; %unignore xla::swig::LocalComputationBuilder::Round; %unignore xla::swig::LocalComputationBuilder::Log; +%unignore xla::swig::LocalComputationBuilder::Log1p; %unignore xla::swig::LocalComputationBuilder::Sign; %unignore xla::swig::LocalComputationBuilder::Cos; %unignore xla::swig::LocalComputationBuilder::Sin; @@ -989,6 +1008,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::ReciprocalF32; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DeleteCompiledLocalComputation; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 11611ac61287da30548c335fac977bdc255396ed..c025127c3cf1871d4def1297ed36c046cae61d4b 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -89,10 +89,12 @@ _UNARY_OPS = [ 'Not', 'Abs', 'Exp', + 'Expm1', 'Floor', 'Round', 'Ceil', 'Log', + 'Log1p', 'Sign', 'Cos', 'Sin', @@ -184,6 +186,14 @@ class LocalBuffer(object): self._delete(self.c_local_shaped_buffer) self.c_local_shaped_buffer = None + def destructure(self): + assert self.c_local_shaped_buffer is not None + result = c_api.DestructureLocalShapedBufferTuple(self.c_local_shaped_buffer) + self.c_local_shaped_buffer = None + size = result.size() + destructured = tuple(LocalBuffer(result.Release(i)) for i in xrange(size)) + return destructured + def is_deleted(self): return self.c_local_shaped_buffer is None diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 375e720f9b433f45ad5adc329104c286184a7510..71e1d60a4e23dbfef333223c396e109533da9365 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -365,6 +365,55 @@ class LocalBufferTest(LocalComputationTest): with self.assertRaises(ValueError): compiled_c.ExecuteWithLocalBuffers([arg_buffer]) + def testDestructureTupleEmpty(self): + t = () + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 0) + + def testDestructureTupleOneArrayElement(self): + t = (np.array([1, 2, 3, 4], dtype=np.int32),) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 1) + array = pieces[0] + got = array.to_py() + want = NumpyArrayS32([1, 2, 3, 4]) + np.testing.assert_equal(want, got) + + def testDestructureTupleTwoArrayElementDifferentType(self): + t = (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + np.array([2, 3, 4, 5], dtype=np.int32)) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + array0, array1 = pieces + got = array0.to_py() + want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) + np.testing.assert_equal(want, got) + got = array1.to_py() + want = NumpyArrayS32([2, 3, 4, 5]) + np.testing.assert_equal(want, got) + + def testDestructureTupleNested(self): + t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5])) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + tuple0, array1 = pieces + got = array1.to_py() + want = NumpyArrayS32([5]) + np.testing.assert_equal(want, got) + got = tuple0.to_py() + self.assertEqual(type(got), tuple) + self.assertEqual(len(got), 2) + np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) + np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) + class SingleOpTest(LocalComputationTest): """Tests for single ops. @@ -571,6 +620,12 @@ class SingleOpTest(LocalComputationTest): c.Exp(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + def testExpm1(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Expm1(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.expm1(arr)) + def testRound(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -583,6 +638,12 @@ class SingleOpTest(LocalComputationTest): c.Log(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.log(arr)) + def testLog1p(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Log1p(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.log1p(arr)) + def testNeg(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..8999cda5ef852d1246bea45a3312575ec1ac0721 --- /dev/null +++ b/tensorflow/compiler/xla/python_api/BUILD @@ -0,0 +1,36 @@ +# Description: +# Python API for XLA. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_library( + name = "types", + srcs = ["types.py"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto_py", + "//third_party/py/numpy", + ], +) + +py_library( + name = "xla_shape", + srcs = ["xla_shape.py"], + visibility = ["//visibility:public"], + deps = [ + ":types", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) + +py_library( + name = "xla_literal", + srcs = ["xla_literal.py"], + visibility = ["//visibility:public"], + deps = [ + ":types", + ":xla_shape", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) diff --git a/tensorflow/compiler/xla/python_api/types.py b/tensorflow/compiler/xla/python_api/types.py new file mode 100644 index 0000000000000000000000000000000000000000..b60f8dce92ace1b2c682374a2605b3a477936bbc --- /dev/null +++ b/tensorflow/compiler/xla/python_api/types.py @@ -0,0 +1,124 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""Utilities for XLA-specific Python types.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 + +# Records corresponsence between a XLA primitive type and Python/Numpy types. +# +# primitive_type: value of type xla_data_pb2.PrimitiveType +# numpy_dtype: corresponsing Numpy "dtype" (like np.float32) +# literal_field_name: name of the field in the LiteralProto message elements +# of this type go into. +# literal_field_type: type of the field named 'literal_field_name'. +# +# TODO(eliben): figure out how to avoid knowing the extra Python type and the +# astype cast when writing into Literals. +TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [ + 'primitive_type', 'numpy_dtype', 'literal_field_name', 'literal_field_type' +]) + +# Maps from XLA primitive types to TypeConversionRecord. +MAP_XLA_TYPE_TO_RECORD = { + xla_data_pb2.F16: + TypeConversionRecord( + primitive_type=xla_data_pb2.F16, + numpy_dtype=np.float16, + literal_field_name='f16s', + literal_field_type=float), + xla_data_pb2.F32: + TypeConversionRecord( + primitive_type=xla_data_pb2.F32, + numpy_dtype=np.float32, + literal_field_name='f32s', + literal_field_type=float), + xla_data_pb2.F64: + TypeConversionRecord( + primitive_type=xla_data_pb2.F64, + numpy_dtype=np.float64, + literal_field_name='f64s', + literal_field_type=float), + xla_data_pb2.S8: + TypeConversionRecord( + primitive_type=xla_data_pb2.S8, + numpy_dtype=np.int8, + literal_field_name='s8s', + literal_field_type=int), + xla_data_pb2.S16: + TypeConversionRecord( + primitive_type=xla_data_pb2.S16, + numpy_dtype=np.int16, + literal_field_name='s16s', + literal_field_type=int), + xla_data_pb2.S32: + TypeConversionRecord( + primitive_type=xla_data_pb2.S32, + numpy_dtype=np.int32, + literal_field_name='s32s', + literal_field_type=int), + xla_data_pb2.S64: + TypeConversionRecord( + primitive_type=xla_data_pb2.S64, + numpy_dtype=np.int64, + literal_field_name='s64s', + literal_field_type=int), + xla_data_pb2.U8: + TypeConversionRecord( + primitive_type=xla_data_pb2.U8, + numpy_dtype=np.uint8, + literal_field_name='s8s', + literal_field_type=int), + xla_data_pb2.U16: + TypeConversionRecord( + primitive_type=xla_data_pb2.U16, + numpy_dtype=np.uint16, + literal_field_name='s16s', + literal_field_type=int), + xla_data_pb2.U32: + TypeConversionRecord( + primitive_type=xla_data_pb2.U32, + numpy_dtype=np.uint32, + literal_field_name='s32s', + literal_field_type=int), + xla_data_pb2.U64: + TypeConversionRecord( + primitive_type=xla_data_pb2.U64, + numpy_dtype=np.uint64, + literal_field_name='s64s', + literal_field_type=int), + xla_data_pb2.PRED: + TypeConversionRecord( + primitive_type=xla_data_pb2.PRED, + numpy_dtype=np.bool, + literal_field_name='preds', + literal_field_type=bool) +} + +# Maps from Numpy dtypes to TypeConversionRecord. +# Note the conversion on the key. Numpy has a known issue wherein dtype hashing +# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, +# when keying by dtype in this dict, we use the string form of dtypes. +MAP_DTYPE_TO_RECORD = { + str(np.dtype(record.numpy_dtype)): record + for record in MAP_XLA_TYPE_TO_RECORD.values() +} diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py new file mode 100644 index 0000000000000000000000000000000000000000..b040098c294ffaae92b72f678947f99289239314 --- /dev/null +++ b/tensorflow/compiler/xla/python_api/xla_literal.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""XLA LiteralProto utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import types +from tensorflow.compiler.xla.python_api import xla_shape + + +def ConvertLiteralToNumpyArray(literal): + """Converts a XLA literal to a Numpy array.""" + element_type = literal.shape.element_type + if element_type == xla_data_pb2.TUPLE: + return tuple( + ConvertLiteralToNumpyArray(subliteral) + for subliteral in literal.tuple_literals) + + type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type] + if not literal.shape.dimensions: + return np.array( + getattr(literal, type_record.literal_field_name)[0], + type_record.numpy_dtype) + else: + # Infer the proper Numpy order from the LiteralProto's layout. The repeated + # field representing the array's content in the Literal is linearized. + # Reading is done in two steps: + # + # 1. Read the array as 1D from the LiteralProto repeated field. + # 2. Reshape the array to its proper shape, using the right order depending + # on the LiteralProto's layout. + layout_order = literal.shape.layout.minor_to_major + numpy_shape = tuple(literal.shape.dimensions) + if layout_order == range(len(literal.shape.dimensions)): + numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='F') + elif layout_order == range(len(literal.shape.dimensions) - 1, -1, -1): + numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C') + else: + raise NotImplementedError('Unsupported layout: {0}'.format(layout_order)) + ndarray = np.array( + getattr(literal, type_record.literal_field_name), + copy=False, + dtype=type_record.numpy_dtype) + return numpy_reshaper(ndarray) + + +def _ConvertNumpyArrayToLiteral(ndarray): + """Converts a Numpy array to a XLA literal.""" + type_record = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)] + literal = xla_data_pb2.LiteralProto() + literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message) + + if ndarray.ndim == 0: + getattr(literal, type_record.literal_field_name).append( + np.asscalar(ndarray.astype(type_record.literal_field_type))) + else: + # Ndarrays with boolean dtypes need special type conversion with protobufs + if ndarray.dtype in {np.bool_, np.dtype('bool')}: + for element in np.nditer(ndarray): + getattr(literal, type_record.literal_field_name).append( + type_record.literal_field_type(element)) + else: + ndarray_flat = ndarray.ravel(order='A') + getattr(literal, type_record.literal_field_name).extend(ndarray_flat) + return literal + + +def ConvertNumpyArrayToLiteral(value): + """Converts a Numpy array or a nested tuple thereof to an XLA literal.""" + if isinstance(value, tuple): + literal = xla_data_pb2.LiteralProto() + literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(value).message) + for component in value: + component_literal = literal.tuple_literals.add() + component_literal.CopyFrom(ConvertNumpyArrayToLiteral(component)) + return literal + else: + return _ConvertNumpyArrayToLiteral(value) diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..6af28958035bbb03e7e1dbb0d0c7bb2c2f25b96d --- /dev/null +++ b/tensorflow/compiler/xla/python_api/xla_shape.py @@ -0,0 +1,155 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""XLA Shape utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python_api import types + + +class Shape(object): + """Wraps a xla_data_pb2.Shape message with a convenient Python type. + + Provides direct access to the underlying xla_data_pb2.Shape message in the + message attribute, along with accessor wrappers to the message's fields. + Avoid direct access to .message unless interacting directly with protobuf APIs + like CopyFrom. In other words, prefer hauling the shape around in a Shape, and + only access .message when strictly required by the protobuf API. + """ + + def __init__(self, element_type, dimensions, layout=None): + """Creates a new XLA Shape. + + Args: + element_type: element type from xla_data_pb2. + dimensions: sequence of dimensions sizes (integers), or sequence + of Shapes in the case of a tuple, i.e. when element_type is + TUPLE. + layout: optional minor_to_major sequence for layout. If not given, the + default major-to-minor layout is used. + + Raises: + ValueError: if element_type is TUPLE but dimensions are not Shape objects. + """ + self.message = xla_data_pb2.Shape() + self.message.element_type = element_type + if element_type == xla_data_pb2.TUPLE: + if not all(isinstance(subshape, Shape) for subshape in dimensions): + raise ValueError( + 'XLA tuple requires sequence of Shape objects as dimensions') + self._tuple_shapes = tuple(dimensions) + for component_shape in self._tuple_shapes: + component_message = self.message.tuple_shapes.add() + component_message.CopyFrom(component_shape.message) + else: + self.message.dimensions.extend(dimensions) + if layout is None: + layout = list(reversed(range(len(dimensions)))) + self.message.layout.format = xla_data_pb2.DENSE + self.message.layout.minor_to_major.extend(layout) + + def element_type(self): + return self.message.element_type + + def is_tuple(self): + return self.element_type() == xla_data_pb2.TUPLE + + def dimensions(self): + if self.is_tuple(): + raise ValueError('Tuple shape has no dimensions. Try tuple_shapes()?') + return self.message.dimensions + + def tuple_shapes(self): + """If this is a tuple, returns its sequence of constituent Shape objects. + + Returns: + Tuple sub-shapes. + + Raises: + ValueError: if this is not a tuple. + """ + if not self.is_tuple(): + raise ValueError('tuple_shapes() called on a non-tuple shape') + return self._tuple_shapes + + def layout(self): + return self.message.layout + + @staticmethod + def from_pyval(pyval): + return CreateShapeFromNumpy(pyval) + + +def _CreateShapeFromNumpy(ndarray): # pylint: disable=invalid-name + """Create a Shape from a given Numpy array. + + Args: + ndarray: Numpy array. + + Returns: + A Shape object. + """ + element_type = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)].primitive_type + dimensions = ndarray.shape + + # Set the shape's layout based on the ordering of ndarray. + # Numpy arrays come in two orders: Fortran (column-major) and C (row-major). + if np.isfortran(ndarray): + # Column-major layout. This corresponds to a "dimension order is + # minor-to-major" layout in XLA. + layout = range(ndarray.ndim) + else: + # Row-major layout. This corresponds to a "dimension order is + # major-to-minor" layout int XLA. + layout = list(reversed(xrange(ndarray.ndim))) + + return Shape(element_type, dimensions, layout) + + +def CreateShapeFromNumpy(value): # pylint: disable=invalid-name + """Create a Shape from a Numpy array or a nested tuple structure thereof. + + Args: + value: Numpy array or (possibly nested) tuple structure that bottoms out in + Numpy arrays. + + Returns: + A Shape object. + """ + if isinstance(value, tuple): + return Shape( + xla_data_pb2.TUPLE, + [CreateShapeFromNumpy(component) for component in value]) + else: + return _CreateShapeFromNumpy(value) + + +def CreateShapeFromDtypeAndTuple(dtype, shape_tuple): # pylint: disable=invalid-name + """Create a shape from a Numpy dtype and a sequence of nonnegative integers. + + Args: + dtype: a numpy dtype, e.g. np.dtype('int32'). + shape_tuple: a sequence of nonnegative integers. + + Returns: + A Shape object. + """ + element_type = types.MAP_DTYPE_TO_RECORD[str(dtype)].primitive_type + return Shape(element_type, shape_tuple) diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 0d56a9a477b15964ad45e798865aa8d2c7385073..0b1cec1925d4424db086f8a3f62c91ede090189c 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -39,10 +39,10 @@ tf_cc_binary( srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "@grpc//:grpc++_unsecure", ], ) @@ -54,6 +54,7 @@ tf_cc_test( ], deps = [ ":grpc_stub", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -61,7 +62,6 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@grpc//:grpc++_unsecure", ], ) @@ -71,9 +71,9 @@ cc_library( hdrs = ["grpc_service.h"], deps = [ ":xla_service_proto", + "//tensorflow:grpc++", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++_unsecure", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 313f11a9a957155eb277dc02ba5d2565c87e0235..d7dd9786a2bbde2d18ae81a9a9d4cc4b2cc38411 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include -#include "grpc++/create_channel.h" -#include "grpc++/security/credentials.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 5cd573167ae8c002ad8f09e8ba3fb25c6f356564..ca1b09b648013ad45d806040c5ddcf11d9e5604e 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ #define TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ -#include "grpc++/server_context.h" +#include "grpcpp/server_context.h" #include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h" #include "tensorflow/compiler/xla/service/service.h" diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index e29908ccec80db76e3b5b856e57382c56430c379..c68c857c304138ff4318e243f66547c6acce1005 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -15,9 +15,9 @@ limitations under the License. // Basic server binary that exposes a xla::Service through a GRPC interface // on a configurable port. -#include "grpc++/security/server_credentials.h" -#include "grpc++/server.h" -#include "grpc++/server_builder.h" +#include "grpcpp/security/server_credentials.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index 92eb19ec0f9696974556be01a93c074846f6c23a..551ae895e05586daec0ffcd425f4950f76bdd50d 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -115,10 +115,6 @@ service XlaService { returns (ComputeConstantResponse) { } - // Retrieves the inferred shape for a value within a computation. - rpc GetLocalShape(GetLocalShapeRequest) returns (GetLocalShapeResponse) { - } - // Requests one or more device handles from the target. The returned device // handles can be used to specify the device on which to execute computations // or transfer data. @@ -132,18 +128,6 @@ service XlaService { returns (CreateChannelHandleResponse) { } - // Requests that the referenced computation be specialized for the provided - // arguments for subsequent execution. This permits things such as value - // specialization. - rpc Specialize(SpecializeRequest) returns (SpecializeResponse) { - } - - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) { - } - // Invokes the provided computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 75961d49a5ecb7ebebb213c450575fe5e66deea3..fb6281ace651bd0bfd128ef639870658f6064102 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -21,13 +21,6 @@ load( "tf_proto_library_py", ) -xla_proto_library( - name = "session_proto", - srcs = ["session.proto"], - visibility = ["//visibility:public"], - deps = ["//tensorflow/compiler/xla:xla_data_proto"], -) - xla_proto_library( name = "hlo_proto", srcs = ["hlo.proto"], @@ -276,6 +269,7 @@ cc_library( "dfs_hlo_visitor.cc", "hlo_computation.cc", "hlo_instruction.cc", + "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", "hlo_sharding.cc", @@ -287,16 +281,17 @@ cc_library( "hlo_computation.h", "hlo_domain_metadata.h", "hlo_instruction.h", + "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", "hlo_sharding.h", ], deps = [ + ":hlo_casting_utils", ":hlo_module_config", ":hlo_proto", ":hlo_reachability", ":name_uniquer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -405,17 +400,6 @@ tf_cc_test( ], ) -cc_library( - name = "versioned_computation_handle", - srcs = ["versioned_computation_handle.cc"], - hdrs = ["versioned_computation_handle.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - tf_cc_test( name = "hlo_instruction_test", srcs = ["hlo_instruction_test.cc"], @@ -595,7 +579,6 @@ cc_library( ":allocation_tracker", ":backend", ":channel_tracker", - ":compilation_cache", ":compiler", ":computation_layout", ":device_memory_allocator", @@ -608,10 +591,8 @@ cc_library( ":hlo_module_config", ":hlo_proto_util", ":platform_util", - ":session_proto", ":source_map_util", ":transfer_manager", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:service_interface", @@ -646,7 +627,6 @@ cc_library( ":platform_util", ":service", ":shaped_buffer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -766,9 +746,7 @@ cc_library( ":hlo_graph_dumper", ":hlo_proto", ":pool", - ":session_proto", ":shaped_buffer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -870,8 +848,6 @@ cc_library( hdrs = ["channel_tracker.h"], deps = [ ":hlo", - ":session_proto", - ":versioned_computation_handle", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1125,6 +1101,7 @@ tf_cc_test( srcs = ["hlo_scheduling_test.cc"], deps = [ ":buffer_value", + ":heap_simulator", ":hlo", ":hlo_ordering", ":hlo_scheduling", @@ -1172,6 +1149,19 @@ tf_cc_test( ], ) +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_creation_utils", srcs = ["hlo_creation_utils.cc"], @@ -1653,7 +1643,6 @@ tf_cc_test( ":hlo_cost_analysis", ":local_service", ":service", - ":versioned_computation_handle", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", @@ -1994,20 +1983,6 @@ tf_cc_test( ], ) -cc_library( - name = "compilation_cache", - srcs = ["compilation_cache.cc"], - hdrs = ["compilation_cache.h"], - deps = [ - ":executable", - ":hlo_module_config", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "layout_assignment", srcs = [ @@ -2149,6 +2124,7 @@ cc_library( ":buffer_liveness", ":buffer_value", ":call_graph", + ":copy_insertion", ":flatten_call_graph", ":hlo", ":hlo_dce", @@ -2156,6 +2132,7 @@ cc_library( ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -2169,6 +2146,7 @@ tf_cc_test( name = "hlo_rematerialization_test", srcs = ["hlo_rematerialization_test.cc"], deps = [ + ":flatten_call_graph", ":hlo", ":hlo_matchers", ":hlo_ordering", @@ -2178,6 +2156,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -2404,7 +2383,6 @@ cc_library( ":hlo_graph_dumper", ":hlo_pass", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", ], ) @@ -2421,6 +2399,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -2568,7 +2547,6 @@ cc_library( name = "hlo_tfgraph_builder", srcs = ["hlo_tfgraph_builder.cc"], hdrs = ["hlo_tfgraph_builder.h"], - visibility = ["//tensorflow/compiler/xla/tools:__pkg__"], deps = [ ":hlo", "//tensorflow/compiler/xla:literal_util", @@ -2599,6 +2577,7 @@ cc_library( hdrs = ["hlo_graph_dumper.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", "//tensorflow/compiler/xla:literal_util", @@ -3025,13 +3004,14 @@ cc_library( cc_library( name = "hlo_casting_utils", hdrs = ["hlo_casting_utils.h"], - deps = [":hlo"], + deps = ["//tensorflow/core:lib"], ) tf_cc_test( name = "hlo_casting_utils_test", srcs = ["hlo_casting_utils_test.cc"], deps = [ + ":hlo", ":hlo_casting_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index dc5f1b31bf8510be404491b7bceb36f73f4cbf75..1fc8fb9b6994db78fe3aa06e1ea790decfce7b97 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -449,7 +449,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( // Filter out and remove empty operands. std::vector nonempty_operands; for (HloInstruction* operand : operands) { - if (!ShapeUtil::HasZeroElements(operand->shape())) { + if (!ShapeUtil::IsZeroElementArray(operand->shape())) { nonempty_operands.push_back(operand); } } @@ -1058,9 +1058,9 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } // Replace a zero element dot with a broadcast of the constant 0. - if (ShapeUtil::HasZeroElements(dot->shape()) || - ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { + if (ShapeUtil::IsZeroElementArray(dot->shape()) || + ShapeUtil::IsZeroElementArray(lhs->shape()) || + ShapeUtil::IsZeroElementArray(rhs->shape())) { auto zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); return ReplaceWithNewInstruction( @@ -1392,7 +1392,7 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { } Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { - if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) { + if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) { return ReplaceWithNewInstruction( pad, HloInstruction::CreateBroadcast(pad->shape(), pad->mutable_operand(1), {})); @@ -1638,7 +1638,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // Reshape directly to empty constant if the shape contains zero-element // dimension. - if (ShapeUtil::HasZeroElements(reshape->shape())) { + if (ShapeUtil::IsZeroElementArray(reshape->shape())) { auto empty_constant = HloInstruction::CreateConstant( Literal::CreateFromShape(reshape->shape())); @@ -1739,7 +1739,7 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( // If any dimension of update is 0, elide the DynamicUpdateSlice. This // optimization becomes invalid should we later prefer to warn about out of // bound indices. - if (ShapeUtil::HasZeroElements(update->shape())) { + if (ShapeUtil::IsZeroElementArray(update->shape())) { return ReplaceInstruction(dynamic_update_slice, dynamic_update_slice->mutable_operand(0)); } @@ -1751,8 +1751,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { auto init_value = reduce->mutable_operand(1); tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); - if (ShapeUtil::HasZeroElements(arg->shape()) || - ShapeUtil::HasZeroElements(reduce->shape())) { + if (ShapeUtil::IsZeroElementArray(arg->shape()) || + ShapeUtil::IsZeroElementArray(reduce->shape())) { return ReplaceWithNewInstruction( reduce, HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); @@ -1783,6 +1783,37 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { return ReplaceWithNewInstruction( reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); } + + // If a reduce feeds a reduce with the same computation and initial value, + // they can be combined into a single reduce. + if (arg->opcode() == HloOpcode::kReduce && + init_value->Identical(*arg->operand(1)) && + *function == *arg->to_apply()) { + // Create a new reduce with the combined reduction dimensions of both + // reduces. + std::vector arg_dims = arg->dimensions(); + std::sort(arg_dims.begin(), arg_dims.end()); + std::vector reduce_dims = reduce->dimensions(); + std::sort(reduce_dims.begin(), reduce_dims.end()); + // Transform reduce_dims to the same rank as the operand of the operand. + for (int64 arg_dim : arg_dims) { + for (int64& dim : reduce_dims) { + if (dim >= arg_dim) { + ++dim; + } + } + } + std::vector new_dimensions; + new_dimensions.reserve(arg->dimensions().size() + + reduce->dimensions().size()); + std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), + reduce_dims.end(), std::back_inserter(new_dimensions)); + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0), + init_value, new_dimensions, function)); + } + // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. @@ -1832,7 +1863,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* reduce_window) { - if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) { + if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) { return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateBroadcast(reduce_window->shape(), @@ -2028,8 +2059,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction* convolution) { auto lhs = convolution->mutable_operand(0); auto rhs = convolution->mutable_operand(1); - if (ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { + if (ShapeUtil::IsZeroElementArray(lhs->shape()) || + ShapeUtil::IsZeroElementArray(rhs->shape())) { return ReplaceWithNewInstruction( convolution, HloInstruction::CreateBroadcast( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index cda157f9fac1639d792fb55b5a5ddac56df271aa..2605b0488cb7c6850746df94c4ab05d6b5d35de5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -74,6 +74,44 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +// Test that Reduce(Reduce(A)) -> Reduce(A) +TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { + HloComputation::Builder builder(TestName()); + // Create add computation. + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r4f32, "param")); + std::vector dims0({0}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7}); + HloInstruction* reduce0 = builder.AddInstruction( + HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation)); + std::vector dims1({1, 2}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); + builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, + dims1, add_computation)); + module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reduce(param, zero)); + EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); +} + // Test that Const + A is canonicalized to A + Const. TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -1714,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1759,7 +1797,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1781,7 +1819,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1804,7 +1842,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1932,7 +1970,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, @@ -2060,7 +2099,7 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2090,7 +2129,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2121,7 +2160,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2151,7 +2190,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2184,7 +2223,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2200,10 +2239,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction::CreateParameter(0, r0f32, "scalar_param")); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, scalar_param, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {})); Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( @@ -2219,10 +2256,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2237,10 +2274,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, forty_two, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {})); HloInstruction* transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -2259,7 +2294,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2268,7 +2303,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2349,7 +2385,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2444,7 +2481,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 598718c72c6941a4859063ed894c45b9c620998e..ec13fadbc75e2315d1d6ef72e24a0faca0c7de40 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -58,8 +58,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. static bool Run(HloComputation* computation, bool rewrite_training_op, - bool rewrite_inference_op, bool rewrite_grad_op, - bool use_fusion); + bool rewrite_inference_op, bool rewrite_grad_op); // Returns whether any batch norm ops were rewritten. const bool changed() const { return changed_; } @@ -70,21 +69,14 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { explicit BatchNormExpanderVisitor(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) + bool rewrite_grad_op) : computation_(computation), rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} + rewrite_grad_op_(rewrite_grad_op) {} HloComputation* GetOrCreateScalarAddComputation( PrimitiveType primitive_type) { - HloComputation** scalar_add_computation = - &scalar_add_computations_[primitive_type]; - if (*scalar_add_computation) { - return *scalar_add_computation; - } - HloComputation::Builder b("scalar_add_computation"); Shape shape = ShapeUtil::MakeShape(primitive_type, {}); auto scalar_lhs = b.AddInstruction( @@ -93,71 +85,38 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { HloInstruction::CreateParameter(1, shape, "scalar_rhs")); auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); - *scalar_add_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_add_computation; - } - - // TODO(b/80534766): Remove maps after performance issues with scalar - // broadcasts are resolved on all backends. - HloComputation* GetOrCreateScalarRsqrtComputation( - PrimitiveType primitive_type) { - HloComputation** scalar_rsqrt_computation = - &scalar_rsqrt_computations_[primitive_type]; - if (*scalar_rsqrt_computation) { - return *scalar_rsqrt_computation; - } - - HloComputation::Builder b("scalar_add_computation"); - Shape shape = ShapeUtil::MakeShape(primitive_type, {}); - auto scalar_lhs = b.AddInstruction( - HloInstruction::CreateParameter(0, shape, "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert( - shape, b.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(-0.5f))))); - auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kPower, scalar_lhs, scalar_rhs)); - *scalar_rsqrt_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_rsqrt_computation; + return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); } - std::unique_ptr Rsqrt(HloInstruction* operand) { - return HloInstruction::CreateMap( - operand->shape(), {operand}, - GetOrCreateScalarRsqrtComputation(operand->shape().element_type())); - } - - HloComputation* GetOrCreateScalarMeanComputation(PrimitiveType primitive_type, - int64 element_count) { - HloComputation** scalar_mean_computation = - &scalar_mean_computations_[std::pair( - primitive_type, element_count)]; - if (*scalar_mean_computation) { - return *scalar_mean_computation; - } - - HloComputation::Builder b("scalar_add_computation"); - Shape shape = ShapeUtil::MakeShape(primitive_type, {}); - auto scalar_lhs = b.AddInstruction( - HloInstruction::CreateParameter(0, shape, "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert( - shape, b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0( - 1.0f / static_cast(element_count)))))); - auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kMultiply, scalar_lhs, scalar_rhs)); - *scalar_mean_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_mean_computation; + std::unique_ptr Rsqrt( + HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(-0.5f))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, + operand, exponent); } - std::unique_ptr Mean(int64 element_count, - HloInstruction* operand) { - return HloInstruction::CreateMap( - operand->shape(), {operand}, - GetOrCreateScalarMeanComputation(operand->shape().element_type(), - element_count)); + std::unique_ptr Mean( + int64 element_count, HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* elem_count_recip = + add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(1.0 / element_count))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, + operand, elem_count_recip); } // Replaces the existing HLO instruction old_instruction, with @@ -189,18 +148,9 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - bool use_fusion_; // Whether rewrite has occurred. bool changed_ = false; - - // Cached computations for adding two scalars. - tensorflow::gtl::FlatMap - scalar_add_computations_; - tensorflow::gtl::FlatMap - scalar_rsqrt_computations_; - tensorflow::gtl::FlatMap, HloComputation*> - scalar_mean_computations_; }; } // namespace @@ -208,13 +158,12 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) { + bool rewrite_grad_op) { BatchNormExpanderVisitor visitor( computation, /*rewrite_training_op=*/rewrite_training_op, /*rewrite_inference_op=*/rewrite_inference_op, - /*rewrite_grad_op=*/rewrite_grad_op, - /*use_fusion=*/use_fusion); + /*rewrite_grad_op=*/rewrite_grad_op); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -290,28 +239,14 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); - // Fuse two parallel reduces together to improve performance. - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum, squared_sum, operand_squared}, - HloInstruction::FusionKind::kInput); - - sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - squared_sum = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // E[X]. - auto mean = add(Mean(elements_per_feature_int64, sum)); + auto mean = add(Mean(elements_per_feature_int64, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(Mean(elements_per_feature_int64, squared_sum)); + auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); // E^2[X]. auto mean_square = @@ -329,7 +264,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, @@ -431,7 +366,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, @@ -545,10 +480,12 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( // rsqrt[Var[X] + epsilon]. auto rsqrt_var_add_epsilon_broadcasted = add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon_activation))); + variance_broadcasted, epsilon_activation), + add)); auto rsqrt_var_add_epsilon = add(Rsqrt( - add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature))); + add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature), + add)); // X - E[X]. auto activation_minus_mean = add_binary( @@ -573,21 +510,6 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple( - {sum_grad_output_times_activiation_minus_mean, grad_beta})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, - HloInstruction::FusionKind::kInput); - - sum_grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - grad_beta = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply, sum_grad_output_times_activiation_minus_mean, @@ -616,8 +538,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = - add(Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon)); + scale_times_rsqrt_var_add_epsilon = add( + Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); auto elements_per_feature_literal = Literal::CreateR0(elements_per_feature_int64); @@ -665,8 +587,8 @@ StatusOr BatchNormExpander::Run(HloModule* module) { bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_, - rewrite_inference_op_, rewrite_grad_op_, - use_fusion_)) { + rewrite_inference_op_, + rewrite_grad_op_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 4ad987085da91684bb7891070afeefd19be4138f..7ae202c583516443a6263403fb5460d1adbabd97 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -31,11 +31,10 @@ class BatchNormExpander : public HloPassInterface { // When use_fusion is set, a multi-output fusion node is created. BatchNormExpander(bool rewrite_training_op = false, bool rewrite_inference_op = false, - bool rewrite_grad_op = false, bool use_fusion = true) + bool rewrite_grad_op = false) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} + rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; tensorflow::StringPiece name() const override { return "batchnorm_expander"; } @@ -47,7 +46,6 @@ class BatchNormExpander : public HloPassInterface { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - bool use_fusion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 28e71c2054f59ba4d5d096bf7d898161877bb42f..f7b4c1405dbc8719d8fba5476e6e41d2921ea877 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -211,6 +211,17 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto builder = HloComputation::Builder(TestName()); + + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("add"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -223,7 +234,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b})); + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, + sum, /*replica_group_ids=*/{}, /*barrier=*/"")); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( @@ -233,7 +245,6 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* tuple = builder.AddInstruction( HloInstruction::CreateTuple({gte_a, convert_gte_b})); - auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(FoldConversions(module.get())); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 1afaefd9df9c5771fb9e134ae9050f3abb00ea4a..830f26422bdc2b3bd789e7d5926bcebac815d34a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -228,6 +228,17 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -239,11 +250,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b})); + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, + /*replica_group_ids=*/{}, /*barrier=*/"")); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); - auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(Normalize(module.get())); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index ed0746980f87ac2bea79c308644dc63769f9e309..ee6b6f69b96216403c48933e424ebbfecd482eee 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -204,6 +204,12 @@ void BFloat16Propagation::DetermineWhileComputationsPrecision( bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const { + // If the subshape isn't floating point then none of the users will be BF16. + const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index); + if (subshape.element_type() != BF16 && subshape.element_type() != F32) { + return false; + } + auto& value_set = dataflow_->GetValueSet(&hlo, index); for (const HloValue* value : value_set.values()) { if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { @@ -257,23 +263,34 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, // If the op propagates precision and it outputs a BF16, then it's OK to // supply BF16 also as the input. In the backward pass, the users shapes // should have already been processed. - PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID; - if (use.instruction->opcode() == HloOpcode::kTuple || - (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && - ShapeUtil::IsTuple(use.instruction->shape()))) { - ShapeIndex use_output_index{use.operand_number}; - for (int64 i : use.operand_index) { - use_output_index.push_back(i); - } - user_output_type = - OutputTypeAfterChange(use.instruction, use_output_index); - } else { - user_output_type = OutputTypeAfterChange(use.instruction, {}); - } if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( - *use.instruction, use.operand_number) && - user_output_type == BF16) { - continue; + *use.instruction, use.operand_number)) { + if (use.instruction->opcode() == HloOpcode::kTuple || + (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && + ShapeUtil::IsTuple(use.instruction->shape()))) { + ShapeIndex use_output_index{use.operand_number}; + for (int64 i : use.operand_index) { + use_output_index.push_back(i); + } + if (OutputTypeAfterChange(use.instruction, use_output_index) == + BF16) { + continue; + } + } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) { + ShapeIndex use_output_index; + for (int64 i = 1; i < use.operand_index.size(); ++i) { + use_output_index.push_back(use.operand_index[i]); + } + if (OutputTypeAfterChange(use.instruction, use_output_index) == + BF16) { + continue; + } + } else { + if (OutputTypeAfterChange(use.instruction, use.operand_index) == + BF16) { + continue; + } + } } return false; } @@ -368,6 +385,7 @@ bool BFloat16Propagation::InstructionIsCandidateForBF16Output( if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && hlo->opcode() != HloOpcode::kTuple && hlo->opcode() != HloOpcode::kGetTupleElement && + hlo->opcode() != HloOpcode::kDomain && hlo->shape().element_type() != BF16) { for (int64 i = 0; i < hlo->operand_count(); ++i) { if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, @@ -559,7 +577,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { - std::list computations_topological_order = + const auto& computations_topological_order = module->MakeComputationPostOrder(); tensorflow::gtl::FlatSet resolved; for (auto comp_it = computations_topological_order.rbegin(); @@ -631,7 +649,7 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { subshape, converted_outputs.element(parent_index), output_index.back())); } - if (ShapeUtil::IsTuple(subshape)) { + if (!ShapeUtil::IsArray(subshape)) { continue; } if (!ShapeUtil::Compatible( @@ -742,7 +760,7 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); - std::list computations_topological_order = + const auto& computations_topological_order = module->MakeComputationPostOrder(); // The first step is a forward pass (parameters to root), where we determine // the potential candidate instructions to use bfloat16 in the outputs that diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 5e1499ee6b6ef397f95f7ed29e808d530777bd07..e2ca689c0649528231c0581a37c145c328652420 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -150,11 +150,11 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); EXPECT_TRUE(LiteralTestUtil::Equal( - dot->operand(0)->literal(), - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)))); + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)), + dot->operand(0)->literal())); EXPECT_TRUE(LiteralTestUtil::Equal( - dot->operand(1)->literal(), - *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)))); + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)), + dot->operand(1)->literal())); } // Tests that BF16 can be propagated through nested tuples. @@ -742,4 +742,43 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) { EXPECT_EQ(add1->shape().element_type(), BF16); } +TEST_F(BFloat16PropagationTest, TupleDomain) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* a_trans = + builder.AddInstruction(HloInstruction::CreateTranspose(shape, a, {0, 1})); + HloInstruction* b_trans = + builder.AddInstruction(HloInstruction::CreateTranspose(shape, b, {0, 1})); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({a_trans, b_trans})); + HloInstruction* domain = builder.AddInstruction( + HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr)); + HloInstruction* a_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 0)); + HloInstruction* b_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 1)); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(a_trans)); + EXPECT_TRUE(OutputsBF16(b_trans)); + EXPECT_TRUE(OutputsBF16(a_gte)); + EXPECT_TRUE(OutputsBF16(b_gte)); + EXPECT_FALSE(OutputsBF16(a)); + EXPECT_FALSE(OutputsBF16(b)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 07b4b14b5ec1bdbc01345091105df69368b0b2fb..8595afca7e735528d9ef29a323696c0661fe971c 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -25,6 +25,7 @@ bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo, case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: case HloOpcode::kWhile: @@ -43,6 +44,7 @@ bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kTuple: case HloOpcode::kWhile: @@ -81,6 +83,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kDomain: case HloOpcode::kGetTupleElement: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -92,6 +95,9 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kTranspose: case HloOpcode::kTuple: return true; + case HloOpcode::kBitcast: + return hlo.shape().element_type() == + hlo.operand(0)->shape().element_type(); case HloOpcode::kDynamicSlice: return operand_index == 0; case HloOpcode::kDynamicUpdateSlice: diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index c0b8bf903923a327fb1378eafb51a7d493d5e62d..afe4b2e1425f9e84320ffd5f08beceaac8168c22 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -135,6 +135,7 @@ Status GatherComputationsByAllocationType( worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. break; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: @@ -632,7 +633,7 @@ Status BufferAssignment::ComputeSummaryStats() { if (module_sequence.size() == module_->computation_count()) { TF_ASSIGN_OR_RETURN( const int64 min_size, - MinimumMemoryForSequence(module_sequence, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 7e86c33687e595ad154361dd7018791299cc56ab..efa4696130ffeff669b0d674438a45c5a9d48ef2 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -371,11 +371,11 @@ TEST_F(BufferAssignmentTest, Basic) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -418,11 +418,11 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { // share anything. auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -477,11 +477,11 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { // have the color 0, which allows the mul and add to share buffers. auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -547,11 +547,11 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { // auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -601,7 +601,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // Creates the main kernel and verifies instruction counts. auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation)); module->AddEntryComputation(builder.Build()); @@ -654,7 +654,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0)); auto exp2 = builder.AddInstruction( @@ -818,7 +818,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0)); auto tanh = builder.AddInstruction( @@ -1496,11 +1496,11 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -1536,7 +1536,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { // be {%rev, %neg, %concat}. This occurs right at the concat itself. auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto log = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param)); auto rev = builder.AddInstruction( @@ -1673,7 +1673,7 @@ class WhileBufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { auto sequence = - CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, xla::MakeUnique(module, sequence), ByteSizeOf, @@ -2103,7 +2103,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module.get()); auto sequence = - CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index a8053d15e124319c5c898f0034b9aaa95a007a89..a23427f00ccd88bb0fe1d973a667f80ca54b14cd 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -57,6 +57,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index e415fb27e6b9780eb22df9e46d30ca8999868f6a..fac0afd672ff3ed083aacf778dd9c4f90a2ee870 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -19,8 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc deleted file mode 100644 index b16907da9e9c909d2639f83895db27d724a84a7b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/compilation_cache.h" - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -std::shared_ptr CompilationCache::Insert( - std::unique_ptr executable, - const HloModuleConfig& module_config) { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = - BuildKey(executable->entry_computation_handle(), module_config); - VLOG(2) << "inserting cache key: " << key; - if (cache_.count(key) == 0) { - cache_.emplace(key, std::move(executable)); - } else { - // Executable already exists in the cache. This can happen if two Execute - // calls for a new computation are received simultaneously by the - // service. In this case, we discard the Executable given as a parameter and - // return what is in the cache. This is necessary because the service relies - // on the cache to keep ownership of the Executable. We only want to store - // one Executable for a given computation version and we can't discard the - // executable which is in the cache because it may be in use. - executable.reset(); - } - return cache_.at(key); -} - -std::shared_ptr CompilationCache::LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = BuildKey(versioned_handle, module_config); - VLOG(2) << "looking up cache key: " << key; - if (cache_.count(key) == 0) { - VLOG(2) << "cache key not found: " << key; - return nullptr; - } else { - std::shared_ptr result = cache_.at(key); - VLOG(2) << "hit executable with module config: " - << result->module_config().compilation_cache_key(); - return result; - } -} - -CompilationCache::CacheKey CompilationCache::BuildKey( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - // The computation shape is represented entirely by its ProgramShape member, - // so just serialize the proto as part of the key. - return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::", - versioned_handle.version, "::", - module_config.compilation_cache_key()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h deleted file mode 100644 index 09989726ae6629aa65cb1dd84c16408a75019fa5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" - -namespace xla { - -// A cache which stores Executables indexed by computation handle and version. -class CompilationCache { - public: - CompilationCache() {} - - // Insert the given Executable into the cache. Return a bare Executable - // pointer for the caller to use. Note: the returned pointer will *not* be the - // same as the given unique pointer if the computation already exists in the - // cache. See comments in the .cc implementation for details of this case. - // - // module_config is provided by the caller, instead of being taken from the - // executable, so that we can insert keys into the compilation cache that are - // devoid of layout (where XLA gets to choose what layout to compile). - // - // A shared_ptr is returned so the caller can keep the Executable from being - // destructed in the event that the Executable is evicted from the - // computation cache (and the cache's shared_ptr to the Executable is - // destructed). - std::shared_ptr Insert(std::unique_ptr executable, - const HloModuleConfig& module_config); - - // Lookup the Executable for the specified versioned computation in the cache. - // Return a shared_ptr to the Executable if it exists in the cache. Return - // nullptr otherwise. - std::shared_ptr LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - - protected: - mutable tensorflow::mutex mutex_; - - // Map from versioned handle with program layout to Executable built - // for that computation version and program layout. - using CacheKey = string; - - CacheKey BuildKey(const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - std::map> cache_ GUARDED_BY(mutex_); - - private: - TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index d8fdccf9bbf1c1788bb4000aa702292362446503..7426672a7a2a9102bd5ea98bd51092982e1e09b4 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -63,7 +63,8 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, StatusOr>> CompileOnlyService::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { + const AotCompilationOptions& options, + std::unique_ptr* metadata) { std::vector> hlo_modules; for (const AotXlaComputationInstance& instance : computations) { TF_RET_CHECK(instance.computation.has_program_shape()); @@ -100,7 +101,8 @@ CompileOnlyService::CompileAheadOfTime( hlo_modules.push_back(std::move(hlo_module)); } - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); + return compiler_->CompileAheadOfTime(std::move(hlo_modules), options, + metadata); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index e6a66c202d6e0df3cb6d165e51beb25abd8ec45c..1ac950bdd66bd034dfdafa8598ec506221e99c2f 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -53,6 +53,12 @@ class CompileOnlyService : public Service { const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options); + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options, + std::unique_ptr* metadata); + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 6f06bba6798bdff51f10d8fe9dc524d8064ba849..6b3b9820f09803c8a04504e6c35c22de51abf04b 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -35,6 +35,27 @@ Compiler::ComputeBackendConfigs(const HloInstruction& hlo, return {}; } +std::unique_ptr +Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo, + se::StreamExecutor* executor) const { + CHECK(executor != nullptr); + return nullptr; +} + +// Define a default version where metadata is not used. +StatusOr>> +Compiler::CompileAheadOfTime( + std::vector> modules, + const AotCompilationOptions& options, + std::unique_ptr* metadata) { + if (metadata != nullptr) { + return Unimplemented( + "Populating AotCompilationMetadata is not implemented on this " + "compiler."); + } + return CompileAheadOfTime(std::move(modules), options); +} + /* static */ std::map* Compiler::GetPlatformCompilerFactories() { static auto* r = new std::map; diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 6c52ffd800d19de83877341d41ef81eee2de7251..99abb9bae32b35652e84cddc7c38dbd97ecb5006 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -94,6 +94,19 @@ class AotCompilationOptions { DebugOptions debug_options_; }; +// Abstract superclass describing metadata produced during ahead-of-time +// compilation. +class AotCompilationMetadata { + public: + AotCompilationMetadata(const AotCompilationMetadata&) = delete; + AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete; + + virtual ~AotCompilationMetadata() = default; + + protected: + AotCompilationMetadata() = default; +}; + // Abstract compiler interface that is subclassed for compilation on a // particular platform. // @@ -166,12 +179,29 @@ class Compiler { ComputeBackendConfigs(const HloInstruction& hlo, se::StreamExecutor* executor) const; + // Returns the backend configuration that the backend chooses by default for + // the given HLO. Returns no configuration if the backend does not support + // configurations for the given HLO. + // + // The stream executor is passed in to provide information about the hardware + // that the backend configurations would be targeting. + virtual std::unique_ptr + ComputeDefaultBackendConfig(const HloInstruction& hlo, + se::StreamExecutor* executor) const; + // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. virtual StatusOr>> CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& options) = 0; + // Similar to CompileAheadOfTime above but AotCompilationMetadata + // has an argument that can be populated during compilation. + virtual StatusOr>> + CompileAheadOfTime(std::vector> modules, + const AotCompilationOptions& options, + std::unique_ptr* metadata); + ///// // The Compiler class also serves as a point to register compiler objects // for the various platforms. diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 53c3a3f7b738687db3098acfaef1ae87860d0440..6975f387b4864bf28ea0ad23d7d4602b5b346e08 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -32,12 +32,21 @@ namespace xla { // mutable layouts. class ComputationLayout { public: + // Creates a new ComputationLayout with the given result layout. + explicit ComputationLayout(ShapeLayout result_layout) + : result_layout_(std::move(result_layout)) {} + // Constructs a ComputationLayout from a ProgramShape. The layouts of the // parameters and results are set to the default layout. Layouts in the // ProgramShape are ignored if ignore_layouts is true. explicit ComputationLayout(const ProgramShape& program_shape, bool ignore_layouts = true); + // Adds a new parameter layout to the computation layout. + void add_parameter_layout(ShapeLayout shape_layout) { + parameter_layouts_.push_back(std::move(shape_layout)); + } + // Returns the layout of a particular parameter. const ShapeLayout& parameter_layout(int64 param_no) const { return parameter_layouts_[param_no]; diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 33d8338809d4e8c7c4774f062c3dda5494543ca6..e0ce2e3555e7746d6df212123fe1f968937cceed 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -472,6 +472,10 @@ class CopyRemover { // between copies added around aliased operations (kWhile) guarantees // this strict order. for (const HloValue* value_a : buffer.values()) { + if (ShapeUtil::IsToken(value_a->shape())) { + // Token values have no representation and cannot interfere. + continue; + } for (const HloValue* value_b : buffer.values()) { if (value_a != value_b) { DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, @@ -613,7 +617,10 @@ class CopyRemover { VLOG(2) << copy->name() << " is not removable"; return false; } - + if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) { + VLOG(2) << copy->name() << " is not removable (shape mismatch)"; + return false; + } const CopyNodes& copy_node = copy_map_.at(copy); ValueNode* src = copy_node.src; ValueNode* dest = copy_node.dest; @@ -947,28 +954,6 @@ class CopyRemover { BufferValueTracker buffer_value_tracker_; }; -// Try to remove as many copies from the module as possible without introducing -// live range interference. Copy instructions (identified by their unique id) in -// the set copies_to_exclude are not considered for removal. -Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(module)); - CopyRemover copy_remover(*alias_analysis, ordering, module); - XLA_VLOG_LINES(3, copy_remover.ToString()); - - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - !ContainsKey(copies_to_exclude, instruction->unique_id())) { - TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); - } - } - } - return Status::OK(); -} - // Add copies to address special constraints on the roots of computations not // related to live range interference: // @@ -1065,13 +1050,23 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { HloInstruction* instruction = pair.first; const ShapeTree& indices_to_copy = pair.second; + ShapeTree copies_added(indices_to_copy.shape()); std::vector users = instruction->users(); TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, instruction->parent()->DeepCopyInstruction( - instruction, &indices_to_copy)); + instruction, &indices_to_copy, &copies_added)); for (HloInstruction* user : users) { TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); } + // Special case copies are not eligible for later copy elision passes. + indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) { + if (has_copy) { + HloInstruction* copy = *copies_added.mutable_element(index); + if (copy != nullptr) { + copy->SetCopyElisionAllowed(false); + } + } + }); if (instruction == instruction->parent()->root_instruction()) { instruction->parent()->set_root_instruction(deep_copy); } @@ -1097,6 +1092,31 @@ void MaybeDumpModule(const string& message, const HloModule& module) { } // namespace +Status RemoveUnnecessaryCopies( + const HloOrdering& ordering, + const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module) { + MaybeDumpModule("after adding copies to resolve interference", *module); + + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + CopyRemover copy_remover(*alias_analysis, ordering, module); + XLA_VLOG_LINES(3, copy_remover.ToString()); + + std::unique_ptr call_graph = CallGraph::Build(module); + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy && + !ContainsKey(copies_to_exclude, instruction->unique_id()) && + instruction->CopyElisionAllowed()) { + TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + } + } + } + MaybeDumpModule("after removing unnecessary copies", *module); + + return Status::OK(); +} + StatusOr CopyInsertion::Run(HloModule* module) { // Copy insertion is performed in three steps: // @@ -1158,14 +1178,10 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); - MaybeDumpModule("after adding copies to resolve interference", *module); - DependencyHloOrdering ordering(module); TF_RETURN_IF_ERROR( RemoveUnnecessaryCopies(ordering, existing_copies, module)); - MaybeDumpModule("after removing unnecessary copies", *module); - TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); MaybeDumpModule("after adding special-case copies", *module); diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 65e3d31e347e2cb249a072e7d06ca10c55401748..0d7b3c20f982cae21e5160fe5be20c85bf940ed7 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -64,6 +64,13 @@ class CopyInsertion : public HloPassInterface { static StatusOr AddCopiesForBufferAssignment(HloModule* module); }; +// Try to remove as many copies from the module as possible without introducing +// live range interference. Copy instructions (identified by their unique id) in +// the set copies_to_exclude are not considered for removal. +Status RemoveUnnecessaryCopies( + const HloOrdering& ordering, + const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 153f062d015e49db11c4c9ae0a2a61e76c020f02..ed1a50f516ee23e0f034bf5c2ed15fac7a70c3cc 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1595,6 +1595,45 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { EXPECT_THAT(condition->root_instruction(), op::Constant()); } +TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) { + string module_string = R"( +HloModule TokensShouldNotBeCopied + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %generate-token = token[] generate-token(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %TokensShouldNotBeCopied () -> s32[] { + %one = s32[] constant(1) + %negative_one = s32[] negate(%one) + %init_token = token[] generate-token() + %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + HloRunner::CreateModuleFromString( + module_string, GetDebugOptionsForTest())); + InsertCopies(module.get()); + + // There should be no copies added because tokens should not be copied. + EXPECT_EQ(CountCopies(*module), 0); +} + std::unique_ptr MakeTrivialCondition(const Shape& shape) { auto builder = HloComputation::Builder("trivial_condition"); builder.AddInstruction( @@ -1636,8 +1675,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_SequentialWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1677,8 +1715,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_ParallelWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1750,8 +1787,7 @@ void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { std::vector tuple_params(num_tuple_inputs); for (int i = 0; i < num_iters; ++i) { auto builder = HloComputation::Builder("BM_ParallelWhiles"); - HloModule module("BM_ManyElementTuple", VersionedComputationHandle(), - config); + HloModule module("BM_ManyElementTuple", config); for (int j = 0; j < num_tuple_inputs; ++j) { tuple_params[j] = builder.AddInstruction( HloInstruction::CreateParameter(j, element_shape, "")); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 278bb1bebfa1a0d76d0268b6b6fcfa87410ceee5..b703be0f39e2032bc58479f0b957f9d8b01a77c3 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -151,7 +151,14 @@ cc_library( "@llvm//:target", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep "@llvm//:x86_disassembler", # fixdeps: keep - ], + ] + select({ + "@org_tensorflow//tensorflow:linux_ppc64le": [ + "@llvm//:powerpc_disassembler", + "@llvm//:powerpc_code_gen", + ], + "//conditions:default": [ + ], + }), alwayslink = True, # Contains compiler registration ) @@ -898,6 +905,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 25b18eff20f901fc34343a12bfbd353ecec49cfb..52da9d6eac7e92188774107dd054396ebd9cd8db 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -264,8 +264,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/false); + /*rewrite_grad_op=*/true); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -304,8 +303,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_device_entry_computation_layout(), - &target_machine_features); + module->mutable_entry_computation_layout(), &target_machine_features); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -550,8 +548,8 @@ StatusOr> CpuCompiler::RunBackend( // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction(), - DFSMemoryScheduler)); + ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -730,7 +728,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); + ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index cf43b74c699ca8cbbef11a0abbaf4d69476f5d77..1093559892ddb9c238fd9c1f7e3d419ec7022776 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -206,8 +206,8 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( tensorflow::gtl::MutableArraySlice buffers) { se::Stream* stream = run_options->stream(); ScopedShapedBuffer result_buffer( - /*on_host_shape=*/host_result_shape(), - /*on_device_shape=*/host_result_shape(), run_options->allocator(), + /*on_host_shape=*/result_shape(), + /*on_device_shape=*/result_shape(), run_options->allocator(), stream->parent()->device_ordinal()); // Move OwningDeviceMemory values which contain the array(s) of the result diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index e75fcb6bc9719f7453d5f0cb52d1673cef1fd3df..3ed7876715f64191f6e652d2b5cb1673df9a1b94 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace { @@ -24,6 +25,7 @@ const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaEnableExperimentalLlvmIrGemm = "xla_enable_experimental_llvm_ir_gemm"; +const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -62,6 +64,43 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; } +static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str, + tensorflow::StringPiece suffix) { + CHECK_GE(str.size(), suffix.size()); + CHECK_EQ(str.substr(str.size() - suffix.size()), suffix); + return str.substr(0, str.size() - suffix.size()); +} + +tensorflow::gtl::optional> LlvmIrGemmTileSize( + const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + auto it = extra_options_map.find(kLlvmIrGemmTileSize); + if (it == extra_options_map.end()) { + return tensorflow::gtl::nullopt; + } + + std::vector tile_components = + tensorflow::str_util::Split(it->second, ':'); + CHECK_EQ(tile_components.size(), 3); + + int64 tile_size_m; + int64 tile_size_k; + int64 tile_size_n_in_vector_width; + + CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m)); + CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k)); + + tensorflow::StringPiece tile_size_n_in_vector_width_str = + RemoveSuffix(tile_components[2], "*vectwidth"); + + CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str, + &tile_size_n_in_vector_width)); + + return std::tuple(tile_size_m, tile_size_k, + tile_size_n_in_vector_width); +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 106dfbbc62dfba8d3de74e0a2ae3bb247bd91d67..429b9e16cbdd6f623919533582481f1640118081 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -29,6 +29,8 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config); bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); tensorflow::gtl::optional LlvmIrGemvTilingFactor( const HloModuleConfig& config); +tensorflow::gtl::optional> LlvmIrGemmTileSize( + const HloModuleConfig& config); } // namespace options } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index d97802ee45d6add3c466577d7624d9ca74e2f380..b877b295814a7e13569a1837ed3e1787f2fc3f56 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -160,9 +160,8 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int32 size_32 = static_cast(size); CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32); - Status s = - TransferBufferToDevice(executor, /*size=*/size, - /*source=*/source, queued_buffer->device_memory()); + Status s = executor->SynchronousMemcpyH2D( + /*host_src=*/source, /*size=*/size, queued_buffer->device_memory()); if (!s.ok()) { queued_buffer->Done(s); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index d77076546f404afc1292bc4b5e902b59e24a1246..58228180ca55ede50c8579bbd73cfdfffc07e208 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -324,11 +324,11 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { int64 column_remainder = k() % tile_cols(); int64 column_limit = k() - column_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols(), is_first_column); - }); + ksl_.ForReturnVoid("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); if (column_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, @@ -341,19 +341,20 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( int64 columns, bool is_first_column) { int64 row_limit = m() - (m() % tile_rows()); - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows(), [&](llvm::Value* row) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = - is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) - : vsl_.GetZeroVector()) - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); + ksl_.ForReturnVoid( + "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows(), [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( @@ -372,7 +373,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // initialized. // } - ksl_.For( + ksl_.ForReturnVoid( "dot.inner.epilg.outer", /*start=*/current_tile_col, /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, @@ -382,7 +383,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( ir_builder_->CreateMul(col, ir_builder_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( + ksl_.ForReturnVoid( "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( @@ -390,7 +391,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( is_first_scalar_col, ir_builder_->getInt1(is_first_tiled_column)); - ksl_.If( + ksl_.IfReturnVoid( setting_result_first_time, /*true_block_generator=*/ [&]() { @@ -571,9 +572,10 @@ void RowMajorMatrixVectorProductEmitter::Emit() { int64 row_remainder = m() % tile_rows(); int64 row_limit = m() - row_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); + ksl_.ForReturnVoid( + "dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); @@ -585,17 +587,17 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( std::vector* vector_accumulators) { int64 column_limit = k() - (k() % tile_cols()); - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols(), [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set( - vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); + ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set(vsl_.Add( + old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( @@ -612,14 +614,15 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( ir_builder_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); + ksl_.ForReturnVoid( + "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); } } @@ -665,6 +668,10 @@ class MatrixMatrixBlockPanelEmitter { // the largest vector register we will use). This can be larger than the // largest vector register supported by the machine -- LLVM will legalize // these large vector widths into legally sized vectors. + // + // `max_vector_count` is the maximum number of vectors of size + // `max_vectorization_width` that we will attempt to process at once. + // // `min_vectorization_width` is the smallest vector width the emitter will use // -- below that it will devolve to using a scalar loop. // @@ -674,12 +681,13 @@ class MatrixMatrixBlockPanelEmitter { class Config { public: explicit Config(PrimitiveType scalar_type, Dimensions dims, - int64 max_vectorization_width, + int64 max_vectorization_width, int64 max_vector_count, int64 min_vectorization_width, int64 tile_size_m, int64 tile_size_k) : scalar_type_(scalar_type), dims_(dims), max_vectorization_width_(max_vectorization_width), + max_vector_count_(max_vector_count), min_vectorization_width_(min_vectorization_width), tile_size_m_(tile_size_m), tile_size_k_(tile_size_k) {} @@ -694,6 +702,7 @@ class MatrixMatrixBlockPanelEmitter { PrimitiveType scalar_type() const { return scalar_type_; } Dimensions dims() const { return dims_; } int64 max_vectorization_width() const { return max_vectorization_width_; } + int64 max_vector_count() const { return max_vector_count_; } int64 min_vectorization_width() const { return min_vectorization_width_; } int64 tile_size_m() const { return tile_size_m_; } @@ -703,6 +712,7 @@ class MatrixMatrixBlockPanelEmitter { PrimitiveType scalar_type_; Dimensions dims_; int64 max_vectorization_width_; + int64 max_vector_count_; int64 min_vectorization_width_; int64 tile_size_m_; int64 tile_size_k_; @@ -721,39 +731,35 @@ class MatrixMatrixBlockPanelEmitter { ksl_(ir_builder_) { CHECK(max_vectorization_width() > 0 && IsPowerOfTwo(static_cast(max_vectorization_width()))); + CHECK_GT(max_vector_count(), 0); CHECK(min_vectorization_width() > 0 && IsPowerOfTwo(static_cast(min_vectorization_width()))); + CHECK_GE(max_vectorization_width(), min_vectorization_width()); CHECK_GT(tile_size_k(), 0); } void Emit(); private: - // This emits a loop that loops over the `n` dimension in multiples of - // `max_vectorization_width` as much as possible and then emits a remainder - // epilogue. - void EmitLoopOverN(); - - // This emits a loop that loops over the `k` dimension in multiples of - // `tile_size_k` as much as possible and then emits a remainder epilogue. - void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start, - llvm::Value* n_end); - - // This emits a loop that loops over the `m` dimension in multiples of - // `tile_size_m` as much as possible and then emits a remainder epilogue. - void EmitLoopOverM(VectorSupportLibrary* vsl, int64 tile_size_k, + // The HandleResiduesOnX helpers split the iteration space for dimension X + // into a multiple of the tile size on dimension X and an epilogue. These + // helpers ultimately call into `EmitTiledGemm` for emitting the + // tiled GEMM kernel. + + void HandleResiduesOnN(); + void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, + llvm::Value* n_end); + void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end); + + // This emits a tiled GEMM kernel. For a detailed description see the comment + // on the implementation. + void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end); - - // This emits the inner reduction loop. This inner reduction loop multiplies - // a tile from the LHS of size [tile_size_m,tile_size_k] and a tile from the - // RHS of size [`tile_size_k`, vls->vector_width()] to update a tile of size - // [`tile_size_m`, vls->vector_width()] in the result. - void EmitTiledReductionLoop(VectorSupportLibrary* vsl, int64 tile_size_k, - llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end, - int64 tile_size_m, llvm::Value* m_start, - llvm::Value* m_end); + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); } @@ -763,6 +769,7 @@ class MatrixMatrixBlockPanelEmitter { int64 max_vectorization_width() const { return config().max_vectorization_width(); } + int64 max_vector_count() const { return config().max_vector_count(); } int64 min_vectorization_width() const { return config().min_vectorization_width(); } @@ -779,16 +786,19 @@ class MatrixMatrixBlockPanelEmitter { KernelSupportLibrary ksl_; }; -void MatrixMatrixBlockPanelEmitter::Emit() { EmitLoopOverN(); } +void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); } -void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() { +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { // We can only iterate the `n` dimension for an extent that is divisible by // the vectorization width. So we emit an outer loop that first processes the // largest extent in `n` that is divisible by max_vectorization_width, then // the largest remaining extent that is divisible by max_vectorization_width / // 2 etc. - int64 current_vectorization_width = max_vectorization_width(); + int64 current_vectorization_width = + max_vector_count() * max_vectorization_width(); + int64 current_vector_count = max_vector_count(); + int64 n_start = 0; while (n_start != dims().n() && current_vectorization_width >= min_vectorization_width()) { @@ -796,53 +806,67 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() { if (n_start != n_end) { VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, ir_builder_, "gebp"); - EmitLoopOverK(&vsl, GetInt64(n_start), GetInt64(n_end)); + HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); n_start = n_end; } - current_vectorization_width /= 2; + if (current_vector_count == 1) { + current_vectorization_width /= 2; + } else { + current_vector_count--; + current_vectorization_width = + current_vector_count * max_vectorization_width(); + } } if (n_start != dims().n()) { VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp"); - ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { + ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { llvm::Value* n_i_next = ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); - EmitLoopOverK(&vsl, n_i, n_i_next); + HandleResiduesOnK(&vsl, n_i, n_i_next); }); } } -void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, - llvm::Value* n_start, - llvm::Value* n_end) { +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { int64 k_start = 0; int64 k_end = dims().k() - (dims().k() % tile_size_k()); if (k_end != k_start) { - EmitLoopOverM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), - n_start, n_end); + HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); k_start = k_end; } if (k_start != dims().k()) { - EmitLoopOverM(vsl, dims().k() - k_start, GetInt64(k_start), - GetInt64(dims().k()), n_start, n_end); + HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); } } -void MatrixMatrixBlockPanelEmitter::EmitLoopOverM( +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { const int64 m_end = dims().m() - dims().m() % tile_size_m(); - EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end, - tile_size_m(), GetInt64(0), GetInt64(m_end)); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), + GetInt64(0), GetInt64(m_end)); if (m_end != dims().m()) { - EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end, - dims().m() - m_end, GetInt64(m_end), - GetInt64(dims().m())); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); } } +// The loop structure is: +// +// Iterate over dimension M as m: +// Iterate over dimension N as n: +// Iterate over dimension K as k: +// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) +// +// I.e. a just a tiled version of a "naive" GEMM. +// // The tiling scheme is as follows: // // Let the LHS be: @@ -904,41 +928,48 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverM( // +-------------------+-------------------+-------------------+--------- // | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... // +-------------------+-------------------+-------------------+--------- -void MatrixMatrixBlockPanelEmitter::EmitTiledReductionLoop( +void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { - ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { - MemoryTile result_memory_tile(vsl, ir_builder_, /*matrix=*/result_, - /*matrix_size_along_minor_dim=*/dims().n(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/dims().k(), - /*major_dim_offset=*/m_i, - /*tile_size_along_major_dim=*/tile_size_m); - - ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { - MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, dims().n(), k_i, - tile_size_k); - std::vector> lhs_tile = - lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); - ksl_.For( - "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { - std::vector rhs_tile = rhs_memory_tile.LoadTile(n_i); - std::vector result_tile = - result_memory_tile.LoadTile(n_i); - for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { - for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { - result_tile[r_m_i] = - vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], - result_tile[r_m_i]); - } - } - result_memory_tile.StoreTile(result_tile, n_i); - }); - }); - }); + ksl_.ForReturnVoid( + "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile( + vsl, ir_builder_, /*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + ksl_.ForReturnVoid( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + TileVariable result_tile_var(vsl, + result_memory_tile.LoadTile(n_i)); + ksl_.ForReturnVoid( + "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, + dims().n(), k_i, tile_size_k); + std::vector> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); + std::vector rhs_tile = + rhs_memory_tile.LoadTile(n_i); + std::vector result_tile = + result_tile_var.Get(); + for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } + } + result_tile_var.Set(result_tile); + }); + + result_memory_tile.StoreTile(result_tile_var.Get(), n_i); + }); + }); } } // namespace @@ -1023,16 +1054,21 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( target, ir_builder_->getInt8(0), size_bytes, target_machine_features_.minimum_alignment_for_allocation(size_bytes)); - int64 max_vector_width = + int64 max_target_vector_width = target_machine_features_.vector_register_num_elements( *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width; + std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = + GetGemmTileSize(); + MatrixMatrixBlockPanelEmitter::Config config( /*scalar_type=*/primitive_type, MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, - /*max_vectorization_width=*/max_vector_width, - /*min_vectorization_width=*/std::min(4, max_vector_width), - /*tile_size_m=*/3, /*tile_size_k=*/5); + /*max_vectorization_width=*/max_target_vector_width, + /*max_vector_count=*/tile_size_n_in_vector_width, + /*min_vectorization_width=*/std::min(4, max_target_vector_width), + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); VLOG(2) << "Emitting GEBP kernel in LLVM IR with config " << config.GetCacheKey(); @@ -1265,8 +1301,11 @@ Status DotOpEmitter::Emit() { // from messing up the vectorization. std::unique_ptr reduction_loop = loop_nest.AddLoop( 0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction", - /*prevent_unrolling=*/lhs_reduction_along_minor_dimension && - rhs_reduction_along_minor_dimension); + /*unroll_mode=*/ + (lhs_reduction_along_minor_dimension && + rhs_reduction_along_minor_dimension) + ? xla::llvm_ir::UnrollMode::kNoUnroll + : xla::llvm_ir::UnrollMode::kDefaultUnroll); // The final entry in the rhs and lhs indexes is the indvar of the // reduction loop. @@ -1341,7 +1380,7 @@ Status DotOpEmitter::Emit() { // the rhs and lhs indexes with the reduction dimensions removed. The terms // from the rhs index are the lower dimensions in the index so we add them // first. - llvm_ir::IrArray::Index target_index; + llvm_ir::IrArray::Index target_index(lhs_index.GetType()); for (int dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { target_index.push_back(lhs_index[dimension]); @@ -1365,10 +1404,13 @@ Status DotOpEmitter::Emit() { Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. llvm::Value* result; + // Use the same index_type for all tensor accesses in the same kernel. + llvm::Type* index_type = ir_builder_->getInt64Ty(); + llvm_ir::IrArray::Index element_index(index_type); llvm::Value* lhs_value = - lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); + lhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_); llvm::Value* rhs_value = - rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); + rhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_); if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) { #define REAL(x) ir_builder_->CreateExtractValue(x, {0}) #define IMAG(x) ir_builder_->CreateExtractValue(x, {1}) @@ -1386,7 +1428,8 @@ Status DotOpEmitter::EmitScalarDot() { } else { result = ir_builder_->CreateFMul(lhs_value, rhs_value); } - target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); + target_array_.EmitWriteArrayElement(/*index=*/element_index, result, + ir_builder_); return Status::OK(); } @@ -1588,8 +1631,8 @@ bool PotentiallyImplementedAsEigenDot( const Shape& lhs_shape = hlo.operand(0)->shape(); const Shape& rhs_shape = hlo.operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { + if (ShapeUtil::IsZeroElementArray(lhs_shape) || + ShapeUtil::IsZeroElementArray(rhs_shape)) { return false; } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index d88ccea0dbc845c0d9a580a5b118c57c888fb557..ed2a18976a0f1a88e7bb4632d3a63167d5c146ad 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -143,6 +143,17 @@ class DotOpEmitter { .value_or(kDefaultTilingFactor); } + std::tuple GetGemmTileSize() const { + // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz + // + // TODO(b/80093688): Tune for other architectures and centralize this + // information in one place. + const std::tuple kDefaultTileSize = + std::tuple(11, 9, 1); + return options::LlvmIrGemmTileSize(hlo_module_config_) + .value_or(kDefaultTileSize); + } + // Returns true if we should use an experimental implementation of GEMM // (general matrix matrix multiplication) if possible. bool EnableExperimentalLlvmIrGemm() const { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index b560b7531c0d24e6f670e61a15dce295d9fa2a49..1a8bedfe6afb4f096ddd4703c312b84d521a7ba5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -64,8 +64,8 @@ bool PotentiallyImplementedAsEigenConvolution( return false; } - if (ShapeUtil::HasZeroElements(input_shape) || - ShapeUtil::HasZeroElements(kernel_shape)) { + if (ShapeUtil::IsZeroElementArray(input_shape) || + ShapeUtil::IsZeroElementArray(kernel_shape)) { return false; } // Make sure input and kernel has the same data type. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 59223fddac2f5f7e2e85de4d37e4b6c5760ae697..5c04f381f2db29867d0a8b67a5442b940f44b884 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -226,10 +226,13 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); return EmitMemcpy(*(copy->operand(0)), *copy); - } else { - // Use the elemental emitter for non-tuple shapes. + } else if (ShapeUtil::IsArray(copy->shape())) { + // Use the elemental emitter for array shapes. return DefaultAction(copy); } + return Unimplemented( + "unsupported operand type %s for copy instruction", + PrimitiveType_Name(copy->shape().element_type()).c_str()); } // Calculate the alignment of a buffer allocated for a given primitive type. @@ -560,7 +563,8 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - llvm_ir::IrArray::Index input_index(index.size()); + llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), + index.size()); llvm::Value* in_bounds_condition = nullptr; for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( @@ -691,7 +695,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(source_index.size()); + llvm_ir::IrArray::Index operand_index(ir_builder_.getInt64Ty(), + source_index.size()); llvm::Value* in_bounds_condition = ir_builder_.getTrue(); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( @@ -765,7 +770,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // value and the current output value. SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - llvm_ir::IrArray::Index selected_index; + llvm_ir::IrArray::Index selected_index(source_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( selected_index_address, {ir_builder_.getInt32(i)}); @@ -1107,7 +1112,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // We are not in the padding, so carry out the computation. int num_dims = num_spatial_dims + 2; - llvm_ir::IrArray::Index input_index(num_dims); + llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), num_dims); for (int i = 0; i < num_spatial_dims; ++i) { input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; } @@ -1115,7 +1120,8 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { input_index[dnums.input_batch_dimension()] = batch; llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); - llvm_ir::IrArray::Index kernel_index(num_dims); + llvm_ir::IrArray::Index kernel_index(ir_builder_.getInt64Ty(), + num_dims); for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = window.dimensions(i).window_reversal() @@ -1682,7 +1688,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( // } llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_); - llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size()); + llvm_ir::IrArray::Index array_index(ir_builder_.getInt64Ty(), + reduce->shape().dimensions_size()); for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; --i) { int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); @@ -1873,7 +1880,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice)); - if (ShapeUtil::HasZeroElements(slice->shape())) { + if (ShapeUtil::IsZeroElementArray(slice->shape())) { return Status::OK(); } @@ -2066,7 +2073,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { // Compute the output index the operand element should be assigned to. // output_index := edge_padding_low + operand_index * (interior_padding + 1) const PaddingConfig& padding_config = pad->padding_config(); - llvm_ir::IrArray::Index output_index; + llvm_ir::IrArray::Index output_index(operand_index.GetType()); for (size_t i = 0; i < operand_index.size(); ++i) { llvm::Value* offset = ir_builder_.CreateMul( operand_index[i], @@ -2528,6 +2535,13 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { return Status::OK(); } +Status IrEmitter::HandleGenerateToken(HloInstruction* gen_token) { + TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0); + // No code to generate, but we need to emit an address for book-keeping. + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token)); + return Status::OK(); +} + Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding @@ -2809,7 +2823,10 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { // For the root node, we write directly to the output buffer of the // function. llvm::Argument* retval = compute_function_->result_arg(); - if (!ShapeUtil::IsNil(target_shape)) { + if ((ShapeUtil::IsArray(target_shape) && + !ShapeUtil::IsZeroElementArray(target_shape)) || + (ShapeUtil::IsTuple(target_shape) && + !ShapeUtil::IsEmptyTuple(target_shape))) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 32c536e18fee86cc60067ba3b25ab1eb0e4233df..e1815c1db7a14dfc90ff646c0fd1e439ffffb2e8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -150,6 +150,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; + Status HandleGenerateToken(HloInstruction* gen_token) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 54af40506dab48b3c2a3a44eb0b5f5fb213a32ec..59ae5acd8b7cea049f09eaf4cc98b41339973c77 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -31,13 +31,15 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { + CHECK_NE(index_type, nullptr); + CHECK(!ShapeUtil::IsTuple(shape_)); CHECK(!ShapeUtil::IsScalar(shape_)); llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); const int64 num_dims = shape_.dimensions_size(); - llvm_ir::IrArray::Index array_index(num_dims); + llvm_ir::IrArray::Index array_index(index_type, num_dims); // Add loops from outer-most to inner-most dimensions. for (int i = LayoutUtil::MinorToMajor(shape_).size() - 1; i >= 0; --i) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 755715634aa70a822b21d25dcae20a8fe053477a..25e182a26d6f21c7eba550020cf17403aa92abf7 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) override; + tensorflow::StringPiece loop_name, llvm::Type* index_type) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc index 92da5f71c23d5e1450b39ea8b7bb8345f6fabb3b..f8c8dd5e93d53db8d87be0208b5cf4daac3464f1 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "third_party/intel_mkl_ml/include/mkl_cblas.h" #include "third_party/intel_mkl_ml/include/mkl_service.h" diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index cd1165e23812861ba9951546b7dd744529232196..c444d151858d3a152a01b99657ffae89ebc6b487 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -427,5 +427,27 @@ llvm::Value* LlvmVariable::Get() const { void LlvmVariable::Set(llvm::Value* new_value) { ir_builder_->CreateStore(new_value, alloca_); } + +TileVariable::TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value) { + for (llvm::Value* initial_vector_value : initial_value) { + storage_.emplace_back(vector_support, initial_vector_value); + } +} + +std::vector TileVariable::Get() const { + std::vector result; + c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); + return result; +} + +void TileVariable::Set(tensorflow::gtl::ArraySlice value) { + CHECK_EQ(value.size(), storage_.size()); + for (int64 i = 0, e = value.size(); i < e; i++) { + storage_[i].Set(value[i]); + } +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index edcaec584997b17dce30b8c46fda4abc78441064..49c2a4e2f4bae9e1672b7d2fe891301bce08bd4b 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -317,6 +318,21 @@ class ScalarVariable : public LlvmVariable { Set(initial_value); } }; + +// This wraps a set of alloca-backed stack variables that can, as a whole, store +// a tile. A "tile" is a sequence of vectors that is typically used as a 2D +// grid of scalar values (e.g. for tiled GEMMs). +class TileVariable { + public: + TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value); + + std::vector Get() const; + void Set(tensorflow::gtl::ArraySlice value); + + private: + std::vector storage_; +}; } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 64678d9d7450974f68817f92526519697a83683c..ee2b455730f8f520db6652f0352f8a96291cac73 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -243,6 +243,8 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; + virtual Status HandleGenerateToken(HloInstructionPtr token) = 0; + // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". virtual Status FinishVisit(HloInstructionPtr root) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 240faebe62f5cee4f61b3c36b5e8f653cfd6db8e..6934e00a4b665e9e6a4302e0c0a8ce1d5bb94373 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -188,6 +188,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleGenerateToken(HloInstructionPtr token) override { + return DefaultAction(token); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 9a8bab353ef6b1e0b05b250d35296bc3cef8bc37..4ccd85307d9abbb2716d3c603bb3829882d061f0 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -456,17 +456,15 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { - // (x == x) && abs(x) != inf + // abs(x) o!= inf, this works because the comparison returns false if + // either operand is NaN. auto type = operand_value->getType(); - auto equal_self = - ir_builder_->CreateFCmpOEQ(operand_value, operand_value); auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_); auto infinity = llvm::ConstantFP::getInfinity(type); auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); - auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite); return ir_builder_->CreateZExt( - result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: return ir_builder_->CreateFNeg(operand_value); @@ -1222,7 +1220,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const Shape& operand_shape = hlo.operand(operand_no)->shape(); // If the operand is scalar, the source index is always {}. if (ShapeUtil::IsScalar(operand_shape)) { - return llvm_ir::IrArray::Index(); + return llvm_ir::IrArray::Index(target_index.GetType()); } // If no implicit broadcast is needed for this operand, returns the target @@ -1234,13 +1232,13 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If implicit broadcast is needed, the source dimensions that are broadcast // have index 0. CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape())); - llvm_ir::IrArray::Index source_index; + llvm_ir::IrArray::Index source_index(target_index.GetType()); for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) { if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) { source_index.push_back(target_index[i]); } else { CHECK_EQ(1, operand_shape.dimensions(i)); - source_index.push_back(ir_builder_->getInt64(0)); + source_index.push_back(target_index.GetConstantWithIndexType(0)); } } return source_index; @@ -1542,9 +1540,14 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); + // Use the same index type for all tensor accesses in the same kernel. + llvm::Type* index_type = index.GetType(); + llvm_ir::IrArray::Index slice_start_index(index_type, rank); for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(hlo->operand(1))(dim_index)); @@ -1554,17 +1557,17 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // TODO(b/74360564): This is implementation defined behavior, but is // currently respected by all implementations. Change this if we ever decide // to oficially document different behavior. - start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, - index[i]->getType()); - llvm::Value* operand_dim_size = llvm::ConstantInt::get( - start_index_value->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* output_dim_size = llvm::ConstantInt::get( - start_index_value->getType(), hlo->shape().dimensions(i)); + start_index_value = + ir_builder_->CreateSExtOrTrunc(start_index_value, index_type); + llvm::Value* operand_dim_size = + index_typed_const(input_hlo->shape().dimensions(i)); + llvm::Value* output_dim_size = + index_typed_const(hlo->shape().dimensions(i)); start_index_value = EmitIntegralMin( ir_builder_->CreateSub(operand_dim_size, output_dim_size), - EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), - start_index_value, /*is_signed=*/true), + EmitIntegralMax(index_typed_const(0), start_index_value, + /*is_signed=*/true), /*is_signed=*/true); start_index_value->setName( @@ -1572,7 +1575,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( slice_start_index[i] = start_index_value; } - llvm_ir::IrArray::Index input_index(rank); + llvm_ir::IrArray::Index input_index(index_type, rank); for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index @@ -1596,17 +1599,18 @@ StatusOr ElementalIrEmitter::EmitElementalGather( const llvm_ir::ElementGenerator& indices_generator = operand_to_generator.at(hlo->operand(1)); + llvm::Type* index_type = index.GetType(); // This is the index into `operand` that holds the element we want to // generate. This index "unsafe" as in the components in here may be // out of bounds. - IrArray::Index unsafe_operand_index; + IrArray::Index unsafe_operand_index(index_type); // First copy in the window indices to unsafe_operand_index. for (int64 i = 0, e = operand_shape.dimensions_size(), unsafe_operand_index_dim = 0; i < e; i++) { if (c_binary_search(dim_numbers.elided_window_dims(), i)) { - unsafe_operand_index.push_back(ir_builder_->getInt64(0)); + unsafe_operand_index.push_back(index.GetConstantWithIndexType(0)); } else { unsafe_operand_index.push_back( index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]); @@ -1614,7 +1618,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( } // This is the index of the index vector in the gather_indices tensor. - IrArray::Index gather_index_index; + IrArray::Index gather_index_index(index_type); { std::vector gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { @@ -1630,8 +1634,8 @@ StatusOr ElementalIrEmitter::EmitElementalGather( auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, int64 dim) { - llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc( - index_component, ir_builder_->getInt64Ty()); + llvm::Value* gather_dim_component_extended = + ir_builder_->CreateSExtOrTrunc(index_component, index_type); unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = ir_builder_->CreateAdd( unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)], @@ -1647,18 +1651,18 @@ StatusOr ElementalIrEmitter::EmitElementalGather( indices_shape.dimensions(dim_numbers.index_vector_dim()); for (int64 i = 0; i < index_vector_size; i++) { gather_index_index[dim_numbers.index_vector_dim()] = - ir_builder_->getInt64(i); + index.GetConstantWithIndexType(i); TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, indices_generator(gather_index_index)); add_to_unsafe_operand_index(gather_dim_component, i); } } - IrArray::Index safe_operand_index; + IrArray::Index safe_operand_index(index_type); for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { safe_operand_index.push_back(ir_builder_->CreateURem( unsafe_operand_index[i], - ir_builder_->getInt64(operand_shape.dimensions(i)))); + index.GetConstantWithIndexType(operand_shape.dimensions(i)))); } return operand_generator(safe_operand_index); @@ -1673,14 +1677,18 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* start_hlo = hlo->operand(2); // Calculate slice start/end indices. const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - llvm_ir::IrArray::Index slice_limit_index(rank); + llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank); + llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank); // Slice intersection gathers (ANDs) conditions on all ranks for which // 'input' is set to 'update' llvm::Value* slice_intersection = ir_builder_->getTrue(); for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + llvm::Type* index_type = index[0]->getType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + llvm_ir::IrArray::Index dim_index(1, index_typed_const(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(start_hlo)(dim_index)); @@ -1690,18 +1698,18 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // TODO(b/74360564): This is implementation defined behavior, but is // currently respected by all implementations. Change this if we ever decide // to oficially document different behavior. - start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, - index[i]->getType()); - llvm::Value* input_dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - - start_index_value = EmitIntegralMin( - ir_builder_->CreateSub(input_dim_size, update_dim_size), - EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), - start_index_value, /*is_signed=*/true), - /*is_signed=*/true); + start_index_value = + ir_builder_->CreateSExtOrTrunc(start_index_value, index_type); + llvm::Value* input_dim_size = + index_typed_const(input_hlo->shape().dimensions(i)); + llvm::Value* update_dim_size = + index_typed_const(update_hlo->shape().dimensions(i)); + + start_index_value = + EmitIntegralMin(ir_builder_->CreateSub(input_dim_size, update_dim_size), + EmitIntegralMax(index_typed_const(0), start_index_value, + /*is_signed=*/true), + /*is_signed=*/true); start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); @@ -1731,7 +1739,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Handle true BB (return data from 'update') SetToFirstInsertPoint(if_data.true_block, ir_builder_); // Compute update index for intersection case. - llvm_ir::IrArray::Index update_index(rank); + llvm_ir::IrArray::Index update_index(index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]); } @@ -1799,7 +1807,8 @@ StatusOr ElementalIrEmitter::EmitElementalPad( SetToFirstInsertPoint(if_data.false_block, ir_builder_); TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, - operand_to_generator.at(hlo->operand(1))({})); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index.GetType()))); ir_builder_->CreateStore(padding_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, ir_builder_); @@ -1826,10 +1835,15 @@ StatusOr ElementalIrEmitter::EmitElementalDot( int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); - std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( - IrName(hlo, "inner"), ir_builder_->getInt64(0), - ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1), - ir_builder_); + llvm::Type* index_type = dot_result_index[0]->getType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + + std::unique_ptr inner_loop = + llvm_ir::ForLoop::EmitForLoop(IrName(hlo, "inner"), index_typed_const(0), + index_typed_const(contracted_dim_size), + index_typed_const(1), ir_builder_); SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_); PrimitiveType primitive_type = hlo->shape().element_type(); @@ -1848,7 +1862,7 @@ StatusOr ElementalIrEmitter::EmitElementalDot( // Given an output index [a,b,c,d,e] in the result, we compute: // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) - IrArray::Index lhs_index, rhs_index; + IrArray::Index lhs_index(index_type), rhs_index(index_type); for (int64 i = 0; i < lhs_dims - 1; i++) { lhs_index.push_back(dot_result_index[i]); diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 8119478ce934da06969024905e5e054e0b509b03..7cf2746947846d1fa729a34b324dd442143728f1 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -116,6 +116,11 @@ StatusOr Executable::ExecuteOnStreamWrapper( if (profile->compute_time_ns() == 0) { profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); } + + const int64 executable_size_in_bytes = SizeInBytes(); + if (executable_size_in_bytes != 0) { + profile->set_executable_size_in_bytes(executable_size_in_bytes); + } } if (profile_ptr != nullptr) { @@ -129,19 +134,7 @@ StatusOr Executable::ExecuteOnStreamWrapper( return return_value; } -Status Executable::DumpSessionModule() { - TF_RET_CHECK(dumping()); - const string& directory_path = - module_config().debug_options().xla_dump_executions_to(); - VersionedComputationHandle versioned_handle = entry_computation_handle(); - // This filename does not include the version number because the computation - // is only ever executed at one version. - string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", versioned_handle.handle.handle(), - session_module_->entry().name().c_str(), ++execution_count_); - return Executable::DumpToDirectory(directory_path, filename, - *session_module_); -} +int64 Executable::SizeInBytes() { return -1; } Status Executable::DumpHloSnapshot() { TF_RET_CHECK(dumping_snapshot()); @@ -156,26 +149,6 @@ Status Executable::DumpHloSnapshot() { return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } -/* static */ Status Executable::DumpToDirectory( - const string& directory_path, string filename, - const SessionModule& session_module) { - tensorflow::Env* env = tensorflow::Env::Default(); - if (!env->IsDirectory(directory_path).ok()) { - // NB! CreateDir does not work reliably with multiple XLA threads -- two - // threads can race to observe the absence of the dump directory and - // simultaneously try to create it, causing the "losing" thread to get a - // "directory already exists" error. - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path)); - } - filename = SanitizeFileName(std::move(filename)); - string file_path = tensorflow::io::JoinPath(directory_path, filename); - string result; - TF_RET_CHECK( - tensorflow::SerializeToStringDeterministic(session_module, &result)); - return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, - result); -} - /* static */ Status Executable::DumpToDirectory( const string& directory_path, string filename, const HloSnapshot& hlo_session) { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 4f0466c544738fa1ec4602ee5104daee8d969c83..98eaeee30a693211ae564a5ef3c373f0364bef11 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -27,9 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -90,8 +88,7 @@ class Executable { // called explicitly for other (async, for example) variants after the stream // has completed. virtual Status PopulateExecutionProfile( - HloExecutionProfile* hlo_execution_profile, - se::StreamExecutor* executor) { + HloExecutionProfile* hlo_execution_profile, se::Stream* stream) { return Status::OK(); } @@ -132,25 +129,15 @@ class Executable { const HloModuleConfig& module_config() const { return hlo_module_->config(); } - // Returns the versioned computation handle of the computation computed by - // this executable. - const VersionedComputationHandle& entry_computation_handle() const { - return hlo_module_->entry_computation_handle(); - } - // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. - const Shape& host_result_shape() const { - return hlo_module_->config().host_entry_computation_layout().result_shape(); + const Shape& result_shape() const { + return hlo_module_->config().entry_computation_layout().result_shape(); } - // TODO(b/74197823): Delete the session module dumping helpers. - void set_session_module(std::unique_ptr session_module) { - session_module_ = std::move(session_module); - } - bool dumping() const { return session_module_ != nullptr; } - SessionModule* session_module() const { return session_module_.get(); } - Status DumpSessionModule(); + // Returns the size of the executable in bytes. Returns -1 by default if the + // method is not overridden to support this kind of query. + virtual int64 SizeInBytes(); // Dumping helpers. void set_hlo_snapshot(std::unique_ptr hlo_snapshot) { @@ -160,10 +147,6 @@ class Executable { HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); } Status DumpHloSnapshot(); - // Dump session_module to directory_path/filename. - static Status DumpToDirectory(const string& directory_path, string filename, - const SessionModule& session_module); - // Dump hlo snapshot to directory_path/filename. static Status DumpToDirectory(const string& directory_path, string filename, const HloSnapshot& hlo_session); @@ -179,9 +162,6 @@ class Executable { // around. const std::unique_ptr hlo_module_; - // SessionModule this was compiled from. Null if not dumping executions. - std::unique_ptr session_module_; - // HloSnapshot this was compiled from. Null if not dumping executions. std::unique_ptr hlo_snapshot_; diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 2d3e4b1fcdf6675955714cab262a8b2ca8ff4297..7cd2c9c136acac46e8e6c548c9e58b9bc8e6e0d2 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -300,7 +300,7 @@ static StatusOr PermuteGatherAndWindowDims( StatusOr GatherExpander::ExpandGather( HloInstruction* gather_instr) { - CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape())); + CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape())); HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); @@ -369,7 +369,7 @@ StatusOr GatherExpander::Run(HloModule* module) { return inst->opcode() == HloOpcode::kGather && // Avoid expanding gather ops that produce zero sized tensors, // instead punt these to ZeroSizedHloElimination. - !ShapeUtil::HasZeroElements(inst->shape()); + !ShapeUtil::IsZeroElementArray(inst->shape()); }; std::vector gather_instrs; diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 5ee67ccb4ae147683c7b41941670c6fc413a0d09..85e28a0dfe38415974e435106a2d0b75863f2df5 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -43,7 +43,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { } Status GenericTransferManager::WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); @@ -52,12 +52,24 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( for (const se::DeviceMemoryBase& element : elements) { element_pointers.push_back(element.opaque()); } - return TransferBufferToDevice(executor, GetByteSizeRequirement(shape), - element_pointers.data(), region); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + stream, GetByteSizeRequirement(shape), element_pointers.data(), region)); + // Ensure the buffer is transferred before we destroy element_pointers. + return stream->BlockHostUntilDone(); +} + +void GenericTransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + std::function>)> done) { + Status status = stream->BlockHostUntilDone(); + if (!status.ok()) { + return done(status); + } + done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer)); } StatusOr> -GenericTransferManager::TransferLiteralFromDevice( +GenericTransferManager::TransferLiteralFromDeviceInternal( se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { VLOG(2) << "transferring literal from device ordinal " << executor->device_ordinal() << "; device buffer: " << device_buffer; @@ -74,9 +86,8 @@ GenericTransferManager::TransferLiteralFromDevice( TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { - if (!ShapeUtil::IsTuple(subshape)) { - TF_RETURN_IF_ERROR(TransferBufferFromDevice( - executor, + if (ShapeUtil::IsArray(subshape)) { + TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ @@ -88,8 +99,8 @@ GenericTransferManager::TransferLiteralFromDevice( return std::move(literal); } -Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const LiteralSlice& literal, +Status GenericTransferManager::TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " @@ -103,9 +114,10 @@ Status GenericTransferManager::TransferLiteralToDevice( TF_RET_CHECK( ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape())); - TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + TF_RET_CHECK(stream->parent()->device_ordinal() == + device_buffer.device_ordinal()); - TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); + TF_RETURN_IF_ERROR(WriteTupleIndexTables(stream, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_host_shape(), @@ -121,16 +133,21 @@ Status GenericTransferManager::TransferLiteralToDevice( if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { source = subliteral.untyped_data(); + return TransferBufferToDevice( + stream, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory); } else { // Relayout data before transferring. relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); source = relayed_out_literal->untyped_data(); + TF_RETURN_IF_ERROR(TransferBufferToDevice( + stream, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory)); + return stream->BlockHostUntilDone(); } - return TransferBufferToDevice( - executor, - /*size=*/GetByteSizeRequirement(device_subshape), source, - &device_memory); } return Status::OK(); }); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 3da9570ef7eebcdf618439f628fb4d5589993e4f..d216fe7d29e8f2e84ab4f558ee5caec32d07a70a 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -41,12 +41,13 @@ class GenericTransferManager : public TransferManager { se::Platform::Id PlatformId() const override; - StatusOr> TransferLiteralFromDevice( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; + void TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + std::function>)> done) override; - Status TransferLiteralToDevice(se::StreamExecutor* executor, - const LiteralSlice& literal, - const ShapedBuffer& device_buffer) override; + Status TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) override; Status TransferLiteralToInfeed(se::StreamExecutor* executor, const LiteralSlice& literal) override; @@ -64,11 +65,14 @@ class GenericTransferManager : public TransferManager { const void* source) override; Status WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) override; private: + StatusOr> TransferLiteralFromDeviceInternal( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer); + // The platform this transfer manager targets. const se::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 6bd9d4c31df5b76820abcb711f910b7c468c057d..af6d298589eb58fbae96158bd264c2b085cb66d1 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -164,6 +164,7 @@ cc_library( "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", @@ -236,6 +237,19 @@ cc_library( ], ) +cc_library( + name = "hlo_execution_profiler", + srcs = ["hlo_execution_profiler.cc"], + hdrs = ["hlo_execution_profiler.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_execution_profile", + "//tensorflow/compiler/xla/service:pool", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + cc_library( name = "gpu_executable", srcs = [ @@ -277,6 +291,7 @@ cc_library( ":backend_configs", ":buffer_allocations", ":cudnn_convolution_runner", + ":hlo_execution_profiler", ":infeed_manager", ":ir_emission_utils", ":partition_assignment", @@ -422,6 +437,34 @@ tf_cc_test( ], ) +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:multi_output_fusion", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "multi_output_fusion_test", + srcs = ["multi_output_fusion_test.cc"], + deps = [ + ":multi_output_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + cc_library( name = "gpu_copy_insertion", srcs = ["gpu_copy_insertion.cc"], @@ -522,6 +565,7 @@ cc_library( ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", + ":multi_output_fusion", ":pad_insertion", ":partition_assignment", ":stream_assignment", @@ -539,7 +583,6 @@ cc_library( "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", - "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -569,7 +612,6 @@ cc_library( "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", - "@llvm//:support", ], alwayslink = True, # Contains compiler registration ) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index e0c73aa73acb7f3313eb54fb07390cb76590433e..f9dccd287d955502858f6c24ccd4de80256fc148 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -42,8 +42,8 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { } // CuDNN does not accept zero-element arguments - if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) || - ShapeUtil::HasZeroElements(conv->operand(1)->shape())) { + if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) || + ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index e5e2a0478a0659986ddec8d6785827b14b9efb56..27d2c3e491bfc2108cbd168d1a5e1575c2eed11f 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -53,11 +53,17 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrAppend; +namespace { // Returns whether operand is a floating-point literal with the given value. bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { - return operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAllFloat(value); + if (operand->opcode() == HloOpcode::kConstant && + operand->literal().IsAllFloat(value)) { + return true; + } + return operand->opcode() == HloOpcode::kBroadcast && + IsFPLiteralWithValue(operand->operand(0), value); } +} // namespace GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, @@ -370,11 +376,17 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( "reduce_window_accum_ptr", ir_builder_); { TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))({})); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index.GetType()))); ir_builder_->CreateStore(init_value, accum_ptr); } - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_); + llvm::Type* index_type = index.GetType(); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return index.GetConstantWithIndexType(c); + }; + + llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); @@ -385,14 +397,14 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_); - IrArray::Index input_index(index.size()); + IrArray::Index input_index(index_type, index.size()); llvm::Value* in_bounds = ir_builder_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = ir_builder_->CreateNSWMul( - index[i], ir_builder_->getInt64(window.dimensions(i).stride())); + index[i], index_typed_const(window.dimensions(i).stride())); input_index[i] = ir_builder_->CreateNSWSub( ir_builder_->CreateNSWAdd(stridden_index, window_index[i]), - ir_builder_->getInt64(window.dimensions(i).padding_low())); + index_typed_const(window.dimensions(i).padding_low())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This @@ -403,7 +415,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( in_bounds, ir_builder_->CreateICmpULT( input_index[i], - ir_builder_->getInt64(operand->shape().dimensions(i)))); + index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = @@ -429,11 +441,13 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( llvm::Value* accum_ptr = ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( hlo->shape().element_type(), module_)); + llvm::Type* index_type = output_index.GetType(); TF_ASSIGN_OR_RETURN(llvm::Value * init_value, - operand_to_generator.at(hlo->operand(1))({})); + operand_to_generator.at(hlo->operand(1))( + IrArray::Index(index_type))); ir_builder()->CreateStore(init_value, accum_ptr); - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_); + llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type); IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions( operand->shape(), hlo->dimensions(), "reduction_dim"); if (!ShapeUtil::IsScalar(hlo->shape())) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index b85721980715e2ce2cd7a689ab12a6cea55ba3f1..decfc40dafafe875fa02bab6695f5c54e522f267 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" @@ -52,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" @@ -159,16 +159,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) { pass.AddPass(); } - // TODO(kramerb): Remove use_fusion once instruction fusion can create - // multi-output fusions from the unfused expander output. pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/true); - - // Rewrite gather ops into smaller ones. - pass.AddPass(); + /*rewrite_grad_op=*/true); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. @@ -211,7 +205,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassPipeline pipeline("layout_assignment"); pipeline.AddPass( - hlo_module->mutable_device_entry_computation_layout(), stream_exec); + hlo_module->mutable_entry_computation_layout(), stream_exec); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -261,6 +255,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); + fusion.AddPass(); + fusion.AddPass(/*is_layout_sensitive=*/true, + /*only_fusion_computations=*/true); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 25d8f720ea4791a4c94efcad6909cd0c113fbe70..f20a828bc1a31ad15298a1d77cd79599aa12faf4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -41,77 +41,6 @@ namespace { using tensorflow::tracing::ScopedAnnotation; -// A helper class for profiling HLO in the course of GPU program execution. -// All of the profiling is guarded internally, to avoid the caller needing to -// have lots of conditionals sprinkled around. -class HloExecutionProfiler { - public: - // If profiling is enabled, start an execution timer running. - explicit HloExecutionProfiler( - bool do_profile, HloExecutionProfile* profile, se::Stream* stream, - const std::vector::SmartPtr>& sub_streams, - const HloComputation* computation) - : do_profile_(do_profile), - profile_(profile), - stream_(stream), - sub_streams_(sub_streams), - computation_(computation) { - if (do_profile_) { - clock_rate_ghz_ = - stream->parent()->GetDeviceDescription().clock_rate_ghz(); - execution_timer_.reset(new se::Timer(stream->parent())); - per_op_timer_.reset(new se::Timer(stream->parent())); - stream->InitTimer(execution_timer_.get()) - .ThenStartTimer(execution_timer_.get()); - stream->InitTimer(per_op_timer_.get()); - } - } - - // If profiling is enabled, sets the total cycle count on the profile from the - // execution timer. - void FinishExecution() { - CHECK(!finished_execution_) << "Call FinishExecution only once!"; - finished_execution_ = true; - if (do_profile_) { - stream_->ThenWaitFor(&sub_streams_); - stream_->ThenStopTimer(execution_timer_.get()); - stream_->BlockHostUntilDone().IgnoreError(); - profile_->set_total_cycles_executed( - *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_); - } - } - - // If profiling is enabled, starts the per-operation timer. - void StartOperation() { - if (do_profile_) { - stream_->ThenStartTimer(per_op_timer_.get()); - } - } - - // If profiling is enabled, stops the per-operation timer and records the time - // that the hlo_instruction took to execute in the profile. - void FinishOperation(const HloInstruction* hlo_instruction) { - if (do_profile_) { - stream_->ThenWaitFor(&sub_streams_); - stream_->ThenStopTimer(per_op_timer_.get()); - stream_->BlockHostUntilDone().IgnoreError(); - profile_->SetCyclesTakenBy( - hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); - } - } - - private: - const bool do_profile_; - double clock_rate_ghz_; - HloExecutionProfile* profile_; - se::Stream* stream_; - const std::vector::SmartPtr>& sub_streams_; - const HloComputation* computation_; - std::unique_ptr execution_timer_; - std::unique_ptr per_op_timer_; - bool finished_execution_ = false; -}; - } // namespace // Implementation note: HLO profiling is always enabled for GPU executables, diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..daddd3738e4bb54f3695a96f6f9ffb9accabe97c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -0,0 +1,82 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +HloExecutionProfiler::HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector::SmartPtr>& sub_streams, + const HloComputation* computation) + : do_profile_(do_profile), + profile_(profile), + stream_(stream), + sub_streams_(sub_streams), + computation_(computation) { + if (do_profile_) { + clock_rate_ghz_ = stream->parent()->GetDeviceDescription().clock_rate_ghz(); + execution_timer_.reset(new se::Timer(stream->parent())); + per_op_timer_.reset(new se::Timer(stream->parent())); + stream->InitTimer(execution_timer_.get()) + .ThenStartTimer(execution_timer_.get()); + stream->InitTimer(per_op_timer_.get()); + } +} + +void HloExecutionProfiler::FinishExecution() { + CHECK(!finished_execution_) << "Call FinishExecution only once!"; + finished_execution_ = true; + if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); + stream_->ThenStopTimer(execution_timer_.get()); + stream_->BlockHostUntilDone().IgnoreError(); + profile_->set_total_cycles_executed( + *computation_, + static_cast(execution_timer_->Nanoseconds() * clock_rate_ghz_)); + } +} + +void HloExecutionProfiler::StartOperation() { + if (do_profile_) { + stream_->ThenStartTimer(per_op_timer_.get()); + } +} + +void HloExecutionProfiler::FinishOperation( + const HloInstruction* hlo_instruction) { + if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); + stream_->ThenStopTimer(per_op_timer_.get()); + stream_->BlockHostUntilDone().IgnoreError(); + profile_->SetCyclesTakenBy( + hlo_instruction, + static_cast(per_op_timer_->Nanoseconds() * clock_rate_ghz_)); + } +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..c9b882ff805c45a57f15df4fe79dc34100c0ceff --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/pool.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A helper class for profiling HLO in the course of GPU program execution. +// All of the profiling is guarded internally, to avoid the caller needing to +// have lots of conditionals sprinkled around. +class HloExecutionProfiler { + public: + // If profiling is enabled, start an execution timer running. + explicit HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector::SmartPtr>& sub_streams, + const HloComputation* computation); + + // If profiling is enabled, sets the total cycle count on the profile from the + // execution timer. + void FinishExecution(); + + // If profiling is enabled, starts the per-operation timer. + void StartOperation(); + + // If profiling is enabled, stops the per-operation timer and records the time + // that the hlo_instruction took to execute in the profile. + void FinishOperation(const HloInstruction* hlo_instruction); + + private: + const bool do_profile_; + double clock_rate_ghz_; + HloExecutionProfile* profile_; + se::Stream* stream_; + const std::vector::SmartPtr>& sub_streams_; + const HloComputation* computation_; + std::unique_ptr execution_timer_; + std::unique_ptr per_op_timer_; + bool finished_execution_ = false; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index f766f968826d960a8e86308f2395301aaa09f1ae..375709150e08996ea6a40f5e9e66a8f8d9287008 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -199,7 +199,7 @@ StatusOr> HloSchedule::Build( // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( schedule->thunk_launch_order_, - CreateMemoryMinimizingSequence( + ScheduleOneComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index e230d538cc2df826778e8d13eaaaf31ec81c57f0..45f0a1c645b2875cf90d2c11cfb66c3dd855d097 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -47,8 +47,7 @@ class HloScheduleTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", VersionedComputationHandle(), - config); + return MakeUnique("test_module", config); } HloVec RemoveHlo(const HloVec& input, diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 061210352cf12e6802d066d311fd2cb481673f15..d420863b8569771b16a03591b6a0ddd0591f7e2e 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -137,7 +137,7 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, } llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, - const ShapeIndex& shape_index, + ShapeIndexView shape_index, llvm::Value* ir_value) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType( ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_); @@ -158,7 +158,7 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, - const ShapeIndex& shape_index) { + ShapeIndexView shape_index) { VLOG(2) << "Binding " << hlo.ToString(); const Shape& hlo_shape = hlo.shape(); @@ -202,7 +202,7 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, << " of " << hlo.ToString(); llvm_ir::IrArray ir_array(base_ptr, ShapeUtil::GetSubshape(hlo.shape(), shape_index)); - alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); + alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index); // The GPU backend emits one kernel per top-level HLO, and LLVM views // execution of one kernel as the "whole program" executed on the GPU. diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 3d34311b4368d17cb074aaf33c71fc865e96387e..a86e6e78c693ac53bb2c70d88b999a4e1273ecad 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -51,7 +51,7 @@ class HloToIrBindings { // Rebinds the given HLO to the LLVM IR value that represent its address. void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value, - const ShapeIndex& shape_index = {}); + ShapeIndexView shape_index = {}); // Unbinds all IR values that's defined in an LLVM function, e.g., function // arguments and stack variables. Global variables will be kept in bindings_. @@ -71,7 +71,7 @@ class HloToIrBindings { // A helper method that returns the base pointer of the IrArray containing the // output of "inst".at the given ShapeIndex. llvm::Value* GetBasePointer(const HloInstruction& hlo, - const ShapeIndex& shape_index = {}) const { + ShapeIndexView shape_index = {}) const { auto it = base_ptrs_.find(&hlo); CHECK(it != base_ptrs_.end()) << hlo.ToString(); return it->second.element(shape_index); @@ -97,7 +97,7 @@ class HloToIrBindings { // Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape. llvm::Value* GetTypedIrValue(const HloInstruction& hlo, - const ShapeIndex& shape_index, + ShapeIndexView shape_index, llvm::Value* ir_value); const BufferAssignment* buffer_assignment_; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 36a1b82a26d84fb557c894f0bf122aef064b052e..64ed3d748febd8281a8e602194b31c937a4a682a 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -40,6 +40,7 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || hlo.opcode() == HloOpcode::kFusion || + hlo.opcode() == HloOpcode::kGather || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || @@ -77,15 +78,14 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (consumer->operand_count() == 2 && - (producer->opcode() == HloOpcode::kDot || - (producer->opcode() == HloOpcode::kFusion && - producer->fused_expression_root()->opcode() == HloOpcode::kDot))) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { int64 other_operand_index = 1 - operand_index; - const HloInstruction* alpha = consumer->operand(other_operand_index); HloInstruction* op1 = nullptr; HloInstruction* op2 = nullptr; - if (consumer->opcode() == HloOpcode::kFusion && + if (consumer->operand_count() == 1 && + consumer->opcode() == HloOpcode::kFusion && consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && Match(consumer->fused_expression_root(), match::Op() @@ -103,10 +103,12 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, op2->opcode() != HloOpcode::kBroadcast) { return false; } - if (IsIEEEFloatingPointScalarConstant(alpha)) { + if (IsIEEEFloatingPointScalarConstant(op2->operand(0))) { return true; } - } else if (consumer->opcode() == HloOpcode::kMultiply) { + } else if (consumer->operand_count() == 2 && + consumer->opcode() == HloOpcode::kMultiply) { + const HloInstruction* alpha = consumer->operand(other_operand_index); // Fuse if 'alpha' is a broadcast of a scalar constant. if (alpha->opcode() == HloOpcode::kBroadcast && alpha->dimensions().empty() && @@ -173,6 +175,14 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Fuse scalar constants into loop fusion nodes, this reduces the number of + // parameters and makes matching scalar broadcasts easier. + if (ShapeUtil::IsEffectiveScalar(producer->shape()) && + consumer->opcode() == HloOpcode::kFusion && + producer->opcode() == HloOpcode::kConstant) { + return true; + } + return IsFusile(*producer) && IsFusile(*consumer) && InstructionFusion::ShouldFuse(consumer, operand_index); } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 426b1d235c3135ff61671481044beed518e2db00..1963d9eef72d41fa0a275bea98f959671fa7e737 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -168,7 +168,7 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); EXPECT_THAT(root->fused_expression_root(), - op::Reduce(op::Broadcast(op::Parameter()), op::Parameter())); + op::Reduce(op::Broadcast(op::Constant()), op::Constant())); } TEST_F(InstructionFusionTest, BitcastIntoAdd) { @@ -255,7 +255,7 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { EXPECT_THAT( root->fused_expression_root(), op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), - op::Broadcast(op::Parameter()))); + op::Broadcast(op::Constant()))); } // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is @@ -339,7 +339,7 @@ TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); EXPECT_THAT(root->fused_expression_root(), op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), - op::Broadcast(op::Parameter()))); + op::Broadcast(op::Constant()))); } // Counts the HLO ops with a given op code in the specified module. @@ -581,5 +581,30 @@ TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { << module->ToString(); } +TEST_F(InstructionFusionTest, FuseScalarConstant) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY FuseScalarConstant { + p0 = f32[] parameter(0) + c0 = f32[] constant(1) + add1 = f32[] add(p0, c0) + b0 = f32[2]{0} broadcast(add1), dimensions={} + c1 = f32[2]{0} constant({1, 2}) + ROOT add2 = f32[2]{0} add(b0, c1) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Broadcast(op::Add(op::Parameter(), op::Constant())), + op::Parameter())); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 67890bfed1136796c83c7ef6912ffc1ab1b7e332..388aa35d7dceeef92dbdb6c8a3bb7fb3796a0b61 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -56,8 +56,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, return type_is_allowed && IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && IsRank2WithNoPadding(output_shape) && - !ShapeUtil::HasZeroElements(lhs_shape) && - !ShapeUtil::HasZeroElements(rhs_shape); + !ShapeUtil::IsZeroElementArray(lhs_shape) && + !ShapeUtil::IsZeroElementArray(rhs_shape); } bool DotImplementedAsGemm(const HloInstruction& dot) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 547af33e9a98c03e1429366172f9a401e385a9d1..d38a496fea689675f780ab5f377f4668bc9f05ca 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -478,12 +478,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = rhs_instruction->shape(); + // TODO(b/110211620): Convert to use i32 index_type when it is possible. + llvm::Type* index_type = ir_builder_.getInt64Ty(); + llvm_ir::IrArray::Index element_index(index_type); if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) { // If the operands are scalar, don't emit any loops. llvm::Value* lhs_value = - lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); + lhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_); llvm::Value* rhs_value = - rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); + rhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_); llvm::Value* result; if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_value, rhs_value, &ir_builder_); @@ -493,7 +496,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { } else { result = ir_builder_.CreateFMul(lhs_value, rhs_value); } - target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_); + target_array.EmitWriteArrayElement(/*index=*/element_index, result, + &ir_builder_); return Status::OK(); } @@ -584,7 +588,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // address. The index into the target address is the concatenation of the rhs // and lhs indexes with the reduction dimensions removed. The terms from the // rhs index are the lower dimensions in the index so we add them first. - llvm_ir::IrArray::Index target_index; + llvm_ir::IrArray::Index target_index(index_type); for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) { if (dimension != lhs_reduction_dimension) { target_index.push_back(lhs_index[dimension]); @@ -610,7 +614,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { - if (ShapeUtil::HasZeroElements(convolution->shape())) { + if (ShapeUtil::IsZeroElementArray(convolution->shape())) { // Emit no code for an empty output. return Status::OK(); } @@ -620,7 +624,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { } Status IrEmitter::HandleFft(HloInstruction* fft) { - if (ShapeUtil::HasZeroElements(fft->shape())) { + if (ShapeUtil::IsZeroElementArray(fft->shape())) { // Emit no code for an empty output. return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index bb47a4280541ce2806472aa9365bb0ef38c0c3b3..c9574c87a3be208915b3d6a32679553eb425d2f0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -120,9 +120,10 @@ Status IrEmitterNested::EmitTargetElementLoop( // For MOF we give the loop emitter an array for every output it should // generate. if (hlo.IsMultiOutputFusion()) { + const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape()); std::vector target_arrays; - for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e; - ++i) { + target_arrays.reserve(num_elems); + for (int64 i = 0; i != num_elems; ++i) { target_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR( @@ -130,6 +131,7 @@ Status IrEmitterNested::EmitTargetElementLoop( .EmitLoop()); std::vector tuple_operand_ptrs; + tuple_operand_ptrs.reserve(num_elems); for (const llvm_ir::IrArray& array : target_arrays) { tuple_operand_ptrs.push_back(array.GetBasePointer()); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index b40b557cab3e13bce2cb522d28464cecdc7c9399..f6f0a45124b9978ba21b306d0d98caaf52e8bcc0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -59,6 +59,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" @@ -282,6 +283,69 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) { // Cannot unroll. return 1; } + +// Returns the llvm type for the indices used in the kernel that contains the +// hlo instruction. Such indices include the index for the parallel loop and +// the indices for the tensors accessed by the kernel. The return type is i32 +// iff the following conditions are met: +// . The launch_size of the kernel is within the range of i32. +// . The sizes of all the tensors accessed within the kernel are within the +// range of i32. +// Otherwise, the return type is i64. +llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, + llvm::IRBuilder<>* ir_builder) { + // Find the unnested hlo instructon for which the kernel is generated for. + const HloInstruction* unnested_hlo = hlo; + const HloComputation* computation = hlo->parent(); + if (computation->IsFusionComputation()) { + unnested_hlo = computation->FusionInstruction(); + } + + auto shape_in_range = [&](const Shape& s) { + bool in_range = true; + ShapeUtil::ForEachSubshape( + s, [&](const Shape& sub_shape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(sub_shape) && + !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { + in_range = false; + } + }); + + return in_range; + }; + + llvm::Type* i64_ty = ir_builder->getInt64Ty(); + // Check launch dimension + if (!IsInt32(launch_size)) { + return i64_ty; + } + + // Check the size of result tensors + if (!shape_in_range(unnested_hlo->shape())) { + return i64_ty; + } + + auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool { + return shape_in_range(operand->shape()); + }; + + // Check the size of input tensors + if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { + return i64_ty; + } + + // Check the size of the internal result tensors + if (unnested_hlo->opcode() == HloOpcode::kFusion) { + if (!c_all_of( + unnested_hlo->fused_instructions_computation()->instructions(), + hlo_shape_in_range)) { + return i64_ty; + } + } + + return ir_builder->getInt32Ty(); +} + } // namespace Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { @@ -501,20 +565,27 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { case HloOpcode::kReduce: { VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); std::vector> thunks; - ArraySlice reduces = + ArraySlice output_instructions = root->opcode() == HloOpcode::kTuple ? root->operands() : ArraySlice(&root, 1); // For multi-output fusion emit an initializer for each tuple element. // Otherwise it's sufficient to just initialize the single output. - for (int i = 0, e = reduces.size(); i != e; ++i) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr initializer_thunk, - BuildInitializerThunk( - fusion, reduces[i] == root ? ShapeIndex() : ShapeIndex({i}))); - thunks.push_back(std::move(initializer_thunk)); + HloInstruction* first_reduce = nullptr; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() == HloOpcode::kReduce) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk(fusion, output_instructions[i] == root + ? ShapeIndex() + : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + first_reduce = + first_reduce == nullptr ? output_instructions[i] : first_reduce; + } } + CHECK(first_reduce != nullptr); thunks.push_back(BuildKernelThunk(fusion)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); @@ -533,29 +604,47 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // fusion is a special case of that. InlinedVector input_gens; InlinedVector init_value_gens; + std::vector> + extra_output_gens; InlinedVector reducers; - for (const HloInstruction* reduce : reduces) { - CHECK_EQ(HloOpcode::kReduce, reduce->opcode()); - // TODO(kramerb): CHECK that layouts are equal. Currently this - // breaks multioutputfusion_test. The test has pre-fused - // instructions, but layout_assignment will not assign any layouts - // for instructions inside of a fused computation. It just removes - // the layouts instead. - CHECK(ShapeUtil::Compatible(reduces[0]->shape(), reduce->shape())); - CHECK(ShapeUtil::Compatible(reduces[0]->operand(0)->shape(), - reduce->operand(0)->shape())); - CHECK(ShapeUtil::Compatible(reduces[0]->operand(1)->shape(), - reduce->operand(1)->shape())); - CHECK(reduces[0]->dimensions() == reduce->dimensions()); - input_gens.push_back(fused_emitter.GetGenerator(reduce->operand(0))); - init_value_gens.push_back( - fused_emitter.GetGenerator(reduce->operand(1))); - reducers.push_back(reduce->to_apply()); + InlinedVector reduce_output_shapes; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + const HloInstruction* inst = output_instructions[i]; + ShapeIndex output_shape_index; + if (root->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + if (inst->opcode() == HloOpcode::kReduce) { + // Shapes, layouts and dimensions must be the same for all reduces + // inside of this fusion. + CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape())); + CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); + CHECK(first_reduce->dimensions() == inst->dimensions()); + input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + init_value_gens.push_back( + fused_emitter.GetGenerator(inst->operand(1))); + reducers.push_back(inst->to_apply()); + reduce_output_shapes.push_back(std::move(output_shape_index)); + } else { + // For extra outputs we can relax shape equality to allow different + // types (with the same number of elements). Layouts still have to + // match. + CHECK(ShapeUtil::CompatibleIgnoringElementType( + first_reduce->operand(0)->shape(), inst->shape())); + CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(), + inst->shape().layout())); + extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + std::move(output_shape_index)); + } } - const Shape& input_shape = reduces[0]->operand(0)->shape(); - return EmitReductionToVector(reduces[0], input_shape, input_gens, - init_value_gens, reduces[0]->dimensions(), - reducers); + const Shape& input_shape = first_reduce->operand(0)->shape(); + return EmitReductionToVector(first_reduce, input_shape, input_gens, + init_value_gens, + first_reduce->dimensions(), reducers, + reduce_output_shapes, extra_output_gens); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -940,11 +1029,33 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { return IrEmitter::HandleCopy(copy); } +Status IrEmitterUnnested::EmitExtraOutputsForReduce( + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { + for (int i = 0; i != extra_output_gens.size(); ++i) { + const HloInstruction* output = reduce->parent()->FusionInstruction(); + llvm::Value* extra_output_address = + GetIrArray(*output, *output, extra_output_gens[i].second) + .EmitArrayElementAddress(index, &ir_builder_, + "extra_output_element_address"); + TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, + extra_output_gens[i].first(index)); + ir_builder_.CreateStore(extra_output_ir_value, extra_output_address); + } + return Status::OK(); +} + Status IrEmitterUnnested::EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, tensorflow::gtl::ArraySlice input_gens, tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers) { + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // Number of elements processed by a single thread. constexpr int64 kTileSize = 16; int64 num_elems = ShapeUtil::ElementsIn(input_shape); @@ -956,6 +1067,20 @@ Status IrEmitterUnnested::EmitReductionToScalar( int64 num_tiles = RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {num_tiles}, {0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + + llvm::Type* index_ty = GetIndexTypeForKernel( + reduce, + launch_dimensions.block_count() * launch_dimensions.threads_per_block(), + &ir_builder_); + + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + // Check whether every thread will process a full tile's worth of elements // without reading outside the bounds of the input. If this is true, we can // skip some bounds checks in the final algorithm. @@ -1004,40 +1129,42 @@ Status IrEmitterUnnested::EmitReductionToScalar( llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index({}))); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; + x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); + llvm_ir::ForLoop::EmitForLoop( + "element_id_in_tile", index_typed_const(0), + index_typed_const(kTileSize), index_typed_const(1), &ir_builder_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &ir_builder_); llvm::Value* x = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)), + ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize)), tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(num_elems)), + ir_builder_.CreateICmpULT(x, index_typed_const(num_elems)), "x_in_bounds", &ir_builder_); // Emit code that reads the input element and accumulates it to // the partial reduction result. llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); } + llvm_ir::IrArray::Index input_index( /*linear=*/x, input_shape, &ir_builder_); llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); @@ -1050,18 +1177,18 @@ Status IrEmitterUnnested::EmitReductionToScalar( {partial_reduction_result_addresses[i], input_address}, partial_reduction_result_addresses[i])); } - return Status::OK(); + return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens); }; // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. llvm::Value* x_end = ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize), - ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize))); + index_typed_const(kTileSize), + ir_builder_.CreateNSWMul(x_in_tiles, index_typed_const(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(x_end, ir_builder_.getInt64(num_elems)), + ir_builder_.CreateICmpULE(x_end, index_typed_const(num_elems)), ir_builder_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); @@ -1112,25 +1239,21 @@ Status IrEmitterUnnested::EmitReductionToScalar( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm::Value* lane_id = ir_builder_.CreateURem( - x_in_tiles, ir_builder_.getInt64(kWarpSize), "lane_id"); + x_in_tiles, index_typed_const(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), + ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)), "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); for (int i = 0; i != num_reduces; ++i) { - ShapeIndex output_shape_index; - if (output->IsMultiOutputFusion()) { - output_shape_index = {i}; - } llvm::Value* output_address = - GetIrArray(*output, *output, output_shape_index) + GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( llvm_ir::IrArray::Index( /*linear=*/ir_builder_.getInt64(0), ShapeUtil::GetSubshape(output->shape(), - output_shape_index), + reduce_output_shapes[i]), &ir_builder_), &ir_builder_, "output_element_address"); TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( @@ -1140,10 +1263,6 @@ Status IrEmitterUnnested::EmitReductionToScalar( }; // Emit a parallel loop that iterates through all input tiles, one per thread. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {num_tiles}, {0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, @@ -1151,14 +1270,18 @@ Status IrEmitterUnnested::EmitReductionToScalar( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + .EmitLoop(IrName(reduce), index_ty); } Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, tensorflow::gtl::ArraySlice input_gens, tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers) { + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // Divide the input matrix into tiles of size Kx1. For example, when the // input matrix is 4x4 and K=2, the tiled matrix looks like // @@ -1178,6 +1301,17 @@ Status IrEmitterUnnested::EmitColumnReduction( // If the height is not a multiple of the tile size, we pad the bottom of the // input matrix. const int64 height_in_tiles = CeilOfRatio(height, kTileSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {height_in_tiles, width}, {1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + + // TODO(b/110211620): Convert to use i32 index_type when it is possible. + llvm::Type* index_ty = ir_builder_.getInt64Ty(); + + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; // linear_index < height_in_tiles * width; @@ -1213,8 +1347,9 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index({}))); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); @@ -1225,24 +1360,27 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm::Value* y_in_tiles = tile_index[0]; llvm::Value* x = tile_index[1]; + y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty); + x = ir_builder_.CreateZExtOrTrunc(x, index_ty); + auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); + llvm_ir::ForLoop::EmitForLoop( + "element_id_in_tile", index_typed_const(0), + index_typed_const(kTileSize), index_typed_const(1), &ir_builder_); // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &ir_builder_); llvm::Value* y = ir_builder_.CreateNSWAdd( - ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)), + ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize)), tile_element_loop->GetIndVarValue()); + // Unless we know the tile is entirely in bounds, we have to emit a // y-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(y, ir_builder_.getInt64(height)), + ir_builder_.CreateICmpULT(y, index_typed_const(height)), "y_in_bounds", &ir_builder_); // Emit code that reads the input element and accumulates it to @@ -1284,17 +1422,18 @@ Status IrEmitterUnnested::EmitColumnReduction( {partial_reduction_result_addresses[i], input_address}, partial_reduction_result_addresses[i])); } - return Status::OK(); + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); } }; // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's // immediately beyond the tile. llvm::Value* y_end = ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize), - ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize))); + index_typed_const(kTileSize), + ir_builder_.CreateNSWMul(y_in_tiles, index_typed_const(kTileSize))); llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.CreateICmpULE(y_end, ir_builder_.getInt64(height)), + ir_builder_.CreateICmpULE(y_end, index_typed_const(height)), ir_builder_.getInt1(height % kTileSize == 0)); // The tile is entirely in bound if "height" is a multiple of kTileSize or // y_end <= height. @@ -1315,17 +1454,13 @@ Status IrEmitterUnnested::EmitColumnReduction( const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; for (int i = 0; i != num_reduces; ++i) { - ShapeIndex output_shape_index; - if (output->IsMultiOutputFusion()) { - output_shape_index = {i}; - } llvm::Value* output_address = - GetIrArray(*output, *output, output_shape_index) + GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( llvm_ir::IrArray::Index( x, ShapeUtil::GetSubshape(output->shape(), - output_shape_index), + reduce_output_shapes[i]), &ir_builder_), &ir_builder_, "output_element_address"); TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( @@ -1335,10 +1470,6 @@ Status IrEmitterUnnested::EmitColumnReduction( }; // Emit a parallel loop that iterate through all input tiles. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {height_in_tiles, width}, {1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, @@ -1346,7 +1477,31 @@ Status IrEmitterUnnested::EmitColumnReduction( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + .EmitLoop(IrName(reduce), index_ty); +} + +static std::pair ComputeTilingSchemeForReduction( + int64 depth, int64 width, int64 kWarpSize) { + constexpr int64 kTargetNumElementsPerThread = 64; + int64 x_tile_size = kTargetNumElementsPerThread; + int64 z_tile_size = 1; + + // Only tile along the x dimension with tile size kTargetNumElementsPerThread + // if doing so doesn't require a slow version of loop with bound check on each + // dimension. A more sophisticated heuristics is to enable tile along the + // x dimension with tile size kTargetNumElementsPerThread when either width is + // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big + // enough so that only a small fraction of the threads execute the slow + // version of loop with bound check. + if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { + x_tile_size = 8; + z_tile_size = 8; + while (depth % z_tile_size != 0) { + z_tile_size -= 1; + } + } + + return std::pair(x_tile_size, z_tile_size); } Status IrEmitterUnnested::EmitRowReduction( @@ -1354,9 +1509,13 @@ Status IrEmitterUnnested::EmitRowReduction( const Shape& input_shape, tensorflow::gtl::ArraySlice input_gens, tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers) { + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // A naive algorithm is: - // 1. Divide the input tensor into tiles of size 1x1xK. + // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. // 2. Partially reduces each tile to a scalar using one thread. // 3. Accumulates that scalar to the output vector using atomic operations. // @@ -1367,15 +1526,15 @@ Status IrEmitterUnnested::EmitRowReduction( // int y = linear_index / width_in_tiles % height; // int z = linear_index / (height * width_in_tiles); // float partial_result = 0; - // for (element_id_in_tile : range(kTileSize)) { - // int x = x_in_tiles * kTileSize + element_id_in_tile; + // for (element_id_in_tile : range(x_tile_size)) { + // int x = x_in_tiles * x_tile_size + element_id_in_tile; // if (x < width) // partial_result = reducer(partial_result, input[z][y][z]); // } // AtomicReducer(&output[y], partial_result); // } // - // Three optimizations are performed. + // Four optimizations are performed. // // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead @@ -1402,29 +1561,44 @@ Status IrEmitterUnnested::EmitRowReduction( // element_id_in_tile, which makes the code more friendly to optimizations // such as LICM. // + // 4. When the width is too small and x_tile_size is less than the target + // number of elements per thread and use a small factor of depth as + // z_tile_size to increase the number of elements calculated by each + // partial sum. This can reduce the needed number of dynamic shfl_down and + // atomic operations. + // // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; // linear_index < depth * height * width_in_tiles; // linear_index += blockDim.x * gridDim.x) { // int x_in_tiles = linear_index % width_in_tiles; // int y = linear_index / width_in_tiles % height; - // int z = linear_index / (height * width_in_tiles); + // int z_in_tiles = linear_index / (height * width_in_tiles); // int warp_id = x_in_tiles / warpSize; // int lane_id = x_in_tiles % warpSize; // float partial_result = 0; // int x = warp_id * kTileSize * warpSize + lane_id; - // if (width % (kTileSize * warpSize) == 0 || - // x + (kTileSize - 1) * warpSize < width) { - // // The entire tile is in bounds. - // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; - // ++element_id_in_tile, x += warpSize) { - // partial_result = Reducer(partial_result, input[z][y][x]); + // if (width % (x_tile_size * warpSize) == 0 || + // x + (x_tile_size - 1) * warpSize < width) { + // // The entire x_tile is in bounds. + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // for (int element_id_in_x_tile = 0; + // element_id_in_x_tile < x_tile_size; + // ++element_id_in_x_tile, x += warpSize) { + // partial_result = Reducer(partial_result, input[z][y][x]); + // } // } // } else { // // The tile is partially in bounds. - // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; - // ++element_id_in_tile, x += warpSize) { - // if (x < width) - // partial_result = Reducer(partial_result, input[z][y][x]); + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // for (int element_id_in_x_tile = 0; element_id_in_x_tile < + // x_tile_size; ++element_id_in_tile, x += warpSize) { + // if (x < width) + // partial_result = Reducer(partial_result, input[z][y][x]); + // } // } // } // for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2) @@ -1435,17 +1609,32 @@ Status IrEmitterUnnested::EmitRowReduction( // AtomicReducer(&output[y], partial_result); // } // - // Choose 8 as the tile size, which matches Eigen's RowReduceKernel. - constexpr int64 kTileSize = 8; + + int64 x_tile_size; + int64 z_tile_size; + std::tie(x_tile_size, z_tile_size) = + ComputeTilingSchemeForReduction(depth, width, kWarpSize); + // Round the width in tiles up to the nearest multiple of kWarpSize, so that // the use of shfl_down is valid. const int64 width_in_tiles = - RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize); + RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), + {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + llvm::Type* index_ty = GetIndexTypeForKernel( + reduce, + launch_dimensions.block_count() * launch_dimensions.threads_per_block(), + &ir_builder_); - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_ty, c); + }; + + auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) { const int num_reduces = reducers.size(); - // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); std::vector partial_reduction_result_addresses; @@ -1453,123 +1642,151 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result." + llvm::Twine(i)); - TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, - init_value_gens[i](llvm_ir::IrArray::Index({}))); + TF_ASSIGN_OR_RETURN( + llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index(index_ty))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } - // Emit an inner for-loop that partially reduces the elements in the given - // tile. - llvm::Value* z = tile_index[0]; + llvm::Value* z_tile = tile_index[0]; llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - llvm::Value* warp_id = ir_builder_.CreateUDiv( - x_tile, ir_builder_.getInt64(kWarpSize), "warp_id"); - llvm::Value* lane_id = ir_builder_.CreateURem( - x_tile, ir_builder_.getInt64(kWarpSize), "lane_id"); - // The x-location of the last element in this tile. - // last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize); - llvm::Value* last_x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize - 1), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); + x_tile = ir_builder_.CreateZExtOrTrunc(x_tile, index_ty); - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); + llvm::Value* warp_id = + ir_builder_.CreateUDiv(x_tile, index_typed_const(kWarpSize), "warp_id"); + llvm::Value* lane_id = + ir_builder_.CreateURem(x_tile, index_typed_const(kWarpSize), "lane_id"); - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize); - llvm::Value* x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - tile_element_loop->GetIndVarValue(), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); - - // Unless we know the tile is entirely in bounds, we have to emit a - // x-in-bounds check before reading from the input. - if (!tile_in_bounds) { - llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)), - "x_in_bounds", &ir_builder_); - - // Points ir_builder_ to the then-block. - llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &ir_builder_); - } - - // Emit code that reads the input element and accumulates it to the - // partial reduction result. - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - { - // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. We - // need to convert that to an index to input_shape (the shape of the - // operand of "reduce"). This conversion is composed of a transposition - // from input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_3d_tensor_shape. - const Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), - {depth, height, width}); - const llvm_ir::IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = - input_3d_tensor_index - .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, &ir_builder_) - .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, - &ir_builder_); - for (int i = 0; i != num_reduces; ++i) { - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, - input_gens[i](input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducers[i], - {partial_reduction_result_addresses[i], input_address}, - partial_reduction_result_addresses[i])); - } + // The x-location of the last element in this z-x-tile. + // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); + llvm::Value* last_x = ir_builder_.CreateNSWAdd( + lane_id, ir_builder_.CreateNSWMul( + index_typed_const(kWarpSize), + ir_builder_.CreateNSWAdd( + index_typed_const(x_tile_size - 1), + ir_builder_.CreateNSWMul( + warp_id, index_typed_const(x_tile_size))))); + + KernelSupportLibrary ksl( + &ir_builder_, + /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, + /*prevent_vectorization=*/false); + + // Emit a for-loop that partially reduces the elements in the given + // z-x-tile. + auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, + int64 x_tile_loop_bound) -> Status { + auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { + llvm::Value* z = ir_builder_.CreateNSWAdd( + z_indvar, + ir_builder_.CreateNSWMul(index_typed_const(z_tile_size), z_tile)); + TF_RETURN_IF_ERROR(ksl.For( + "x_tile", + /*start=*/index_typed_const(0), + /*end=*/index_typed_const(x_tile_loop_bound), + /*step=*/1, [&](llvm::Value* x_indvar) -> Status { + // x = lane_id + + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); + llvm::Value* x = ir_builder_.CreateNSWAdd( + lane_id, + ir_builder_.CreateNSWMul( + index_typed_const(kWarpSize), + ir_builder_.CreateNSWAdd( + x_indvar, ir_builder_.CreateNSWMul( + warp_id, llvm::ConstantInt::get( + index_ty, x_tile_size))))); + + // Unless we know the x-tile is entirely in bounds, we have to + // emit a x-in-bounds check before reading from the input. + if (!x_tile_in_bounds) { + llvm_ir::LlvmIfData if_x_in_bounds_data = + llvm_ir::EmitIfThenElse( + ir_builder_.CreateICmpULT(x, index_typed_const(width)), + "x_in_bounds", &ir_builder_); + // Points ir_builder_ to the then-block. + llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, + &ir_builder_); + } + + // Emit code that reads the input element and accumulates it + // to the partial reduction result. + llvm::Value* input_address = + ir_builder_.CreateAlloca(element_ir_type); + { + // {z,y,x} is an index to input_3d_tensor_shape + // [depth,height,width]. We need to convert that to an index + // to input_shape (the shape of the operand of "reduce"). + // This conversion is composed of a transposition from + // input_shape to normalized_input_shape and a reshape from + // normalized_input_shape to input_3d_tensor_shape. + const Shape normalized_input_shape = ShapeUtil:: + MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = + LayoutUtil::MinorToMajor(input_shape); + const std::vector transpose_dimension_mapping( + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); + const Shape input_3d_tensor_shape = + ShapeUtil::MakeShapeWithDescendingLayout( + input_shape.element_type(), {depth, height, width}); + const llvm_ir::IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &ir_builder_); + const llvm_ir::IrArray::Index input_index = + input_3d_tensor_index + .SourceIndexOfReshape(input_3d_tensor_shape, + normalized_input_shape, + &ir_builder_) + .SourceIndexOfTranspose( + normalized_input_shape, input_shape, + transpose_dimension_mapping, &ir_builder_); + + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); + } + })); return Status::OK(); - } + }; + + return ksl.For("z_tile", + /*start=*/index_typed_const(0), + /*end=*/index_typed_const(z_tile_size), + /*step=*/1, emit_z_tile_element_loop); }; llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0), - ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width))); - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - - // After the if-then-else statement on tile_in_bounds, emit calls to - // shfl_down that accumulate the partial reduction results of all threads - // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0), + ir_builder_.CreateICmpULT(last_x, index_typed_const(width))); + + TF_RETURN_IF_ERROR( + ksl.If(tile_in_bounds, + /*true_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, + x_tile_size); + }, + /*false_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop( + /*x_tile_in_bounds=*/false, + CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); + })); + + // After accumulating the elements of the z_x_tile, emit calls to + // shfl_down that accumulate the partial reduction results of all + // threads in a warp. int bit_width = llvm_ir::GetSizeInBits(element_ir_type); // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. @@ -1605,36 +1822,35 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), + ir_builder_.CreateICmpEQ(lane_id, index_typed_const(0)), "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); for (int i = 0; i != num_reduces; ++i) { - ShapeIndex output_shape_index; - if (output->IsMultiOutputFusion()) { - output_shape_index = {i}; - } llvm::Value* output_address = - GetIrArray(*output, *output, output_shape_index) + GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( llvm_ir::IrArray::Index( y, ShapeUtil::GetSubshape(output->shape(), - output_shape_index), + reduce_output_shapes[i]), &ir_builder_), &ir_builder_, "output_element_address"); - TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( - *reducers[i], output_address, partial_reduction_result_addresses[i])); + if (x_tile_size * z_tile_size < depth * width) { + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i])); + } else { + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {output_address, partial_reduction_result_addresses[i]}, + output_address)); + } } return Status::OK(); }; // Emit a parallel loop that iterates through every input tiles. - Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {depth, height, width_in_tiles}, - {2, 1, 0}); - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); UpdateLaunchDimensions( launch_dimensions, @@ -1642,7 +1858,7 @@ Status IrEmitterUnnested::EmitRowReduction( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, launch_dimensions, &ir_builder_) - .EmitLoop(IrName(reduce)); + .EmitLoop(IrName(reduce), index_ty); } // Figures out whether `reduce` is a row or column reduction, and which @@ -1656,7 +1872,11 @@ Status IrEmitterUnnested::EmitReductionToVector( tensorflow::gtl::ArraySlice input_gens, tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - tensorflow::gtl::ArraySlice reducers) { + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // This emission requires "reduce" to have an input layout. It is either set // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for // a fused kReduce). @@ -1692,7 +1912,8 @@ Status IrEmitterUnnested::EmitReductionToVector( // dimension of the input is to keep. if (input_dims_to_keep.empty()) { return EmitReductionToScalar(reduce, input_shape, input_gens, - init_value_gens, reducers); + init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } else if (input_dims_to_keep.front() == LayoutUtil::Minor(input_shape.layout(), 0)) { // Column reduction. Treat the result of "input" as a matrix whose width @@ -1710,7 +1931,8 @@ Status IrEmitterUnnested::EmitReductionToVector( } } return EmitColumnReduction(height, width, reduce, input_shape, input_gens, - init_value_gens, reducers); + init_value_gens, reducers, reduce_output_shapes, + extra_output_gens); } else { // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a // 3D tensor. The size of dimension 1 (the height) is the size of the @@ -1736,7 +1958,8 @@ Status IrEmitterUnnested::EmitReductionToVector( } const int64 height = ShapeUtil::ElementsIn(reduce->shape()); return EmitRowReduction(depth, height, width, reduce, input_shape, - input_gens, init_value_gens, reducers); + input_gens, init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } } @@ -1747,7 +1970,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { HloComputation* reducer = reduce->to_apply(); // HandleReduce specializes reduction from a multi-dimensional array to a 1D // array. The specialized version requires an initializer thunk that - // initializes the output array to the initial value of the reduce. + // ingitializes the output array to the initial value of the reduce. if (IsReductionToVector(*reduce) && // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits 32 <= primitive_util::BitWidth(reduce->shape().element_type())) { @@ -1768,7 +1991,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &ir_builder_); }}, - dimensions_to_reduce, {reducer}); + dimensions_to_reduce, {reducer}, {{}}, {}); } thunk_sequence_->emplace_back(BuildKernelThunk(reduce)); @@ -1835,6 +2058,14 @@ Status IrEmitterUnnested::HandleSelectAndScatter( "Dilation for SelectAndScatter not implemented on GPU."); } + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + source->shape(), ir_emitter_context_->device_description()); + llvm::Type* index_type = GetIndexTypeForKernel( + select_and_scatter, launch_dimensions.launch_bound(), &ir_builder_); + auto index_typed_const = [&](uint64 c) -> llvm::Constant* { + return llvm::ConstantInt::get(index_type, c); + }; + // kSelectAndScatter is implemented as two kernel launches: the first launch // initializes the output array to the given initial value, // and the second accumulates the "source" matrix to the @@ -1865,8 +2096,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( "selected_value_address", &ir_builder_); llvm::Value* selected_index_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank), - "selected_index_address", &ir_builder_); + index_type, index_typed_const(rank), "selected_index_address", + &ir_builder_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_); ir_builder_.CreateStore(ir_builder_.getInt1(false), @@ -1874,7 +2105,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), - &ir_builder_); + &ir_builder_, index_type); std::vector window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); @@ -1888,17 +2119,17 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // Compute the operand index to visit and evaluate the condition whether the // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. - llvm_ir::IrArray::Index operand_index(source_index.size()); + llvm_ir::IrArray::Index operand_index(index_type, source_index.size()); llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( - source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); + source_index[i], index_typed_const(window.dimensions(i).stride())); operand_index[i] = ir_builder_.CreateNSWSub( ir_builder_.CreateNSWAdd(strided_index, window_index[i]), - ir_builder_.getInt64(window.dimensions(i).padding_low())); + index_typed_const(window.dimensions(i).padding_low())); llvm::Value* index_condition = ir_builder_.CreateICmpULT( operand_index[i], - ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + index_typed_const(ShapeUtil::GetDimension(operand->shape(), i))); in_bounds_condition = ir_builder_.CreateAnd(in_bounds_condition, index_condition); } @@ -1970,7 +2201,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // value and the current output value. llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(), &ir_builder_); - llvm_ir::IrArray::Index selected_index; + llvm_ir::IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( selected_index_address, {ir_builder_.getInt32(i)}); @@ -1988,8 +2219,6 @@ Status IrEmitterUnnested::HandleSelectAndScatter( source_value_address); }; - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - source->shape(), ir_emitter_context_->device_description()); UpdateLaunchDimensions( launch_dimensions, // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk @@ -2000,7 +2229,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_emitter_context_->llvm_module()); return ParallelLoopEmitter(loop_body_emitter, source->shape(), launch_dimensions, &ir_builder_) - .EmitLoop(IrName(select_and_scatter)); + .EmitLoop(IrName(select_and_scatter), index_type); } Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { @@ -2082,6 +2311,10 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } +Status IrEmitterUnnested::HandleGenerateToken(HloInstruction* gen_token) { + return Status::OK(); +} + Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { thunk_sequence_->emplace_back(BuildInfeedThunk(infeed)); return Status::OK(); @@ -2205,11 +2438,6 @@ GetHloBufferSlices(const HloInstruction* hlo, return slices; } -Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { - // TODO(b/72710576): Gather is not implemented on GPUs - return Unimplemented("Gather is not implemented on GPUs."); -} - std::unique_ptr IrEmitterUnnested::BuildKernelThunk( const HloInstruction* inst, int unroll_factor) { const BufferAssignment& buffer_assn = @@ -2390,7 +2618,9 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (alpha->opcode() == HloOpcode::kBroadcast) { alpha = alpha->operand(0); } - alpha = inst->operand(alpha->parameter_number()); + if (alpha->opcode() == HloOpcode::kParameter) { + alpha = inst->operand(alpha->parameter_number()); + } // TODO(b/74185543): Remove the following if block once we support fusion // with a non-constant as well. Then we will just always use the constant // on the device. @@ -2436,7 +2666,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( const HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value = [&] { + const HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: return inst->operand(2); @@ -2456,6 +2686,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( } }(); + const HloInstruction* init_value = init_value_operand; if (fused && init_value->opcode() == HloOpcode::kParameter) { init_value = hlo->operand(init_value->parameter_number()); } @@ -2479,8 +2710,9 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by // repeating the literal 4 or 2 times, so long as the destination buffer is // an even multiple of 32 bits long. + const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index); if ((num_bytes == 1 || num_bytes == 2) && - ShapeUtil::ByteSizeOf(hlo->shape()) % 4 == 0) { + ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) { uint16 pattern16; if (num_bytes == 1) { uint8 b = literal_bytes.front(); @@ -2507,13 +2739,24 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( // Otherwise fall back to our slow initializer code. std::unique_ptr kernel_thunk = BuildKernelThunk(hlo); - TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( - *hlo, - [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &ir_builder_); - }, - kernel_thunk.get())); + LaunchDimensions launch_dimensions = + CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index), + ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + // If the init_value was fused into this reduce we have to generate it first. + if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { + CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); + TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + } + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &ir_builder_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &ir_builder_) + .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally // done in IrEmitterUnnested::Postprocess().) @@ -2697,7 +2940,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( if (!hlo.IsMultiOutputFusion()) { return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), launch_dimensions, &ir_builder_, unroll_factor) - .EmitLoop(IrName(&hlo)); + .EmitLoop(IrName(&hlo), + GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), + &ir_builder_)); } // For multiple outputs fusion, we need to emit each operand and the root. @@ -2705,10 +2950,12 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } - TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, - launch_dimensions, &ir_builder_, - unroll_factor) - .EmitLoop(IrName(&hlo))); + TF_RETURN_IF_ERROR( + ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, + &ir_builder_, unroll_factor) + .EmitLoop(IrName(&hlo), + GetIndexTypeForKernel( + &hlo, launch_dimensions.launch_bound(), &ir_builder_))); std::vector tuple_operand_ptrs; for (int64 i = 0; i < output_arrays.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b41eaa303b0aad104ad0369438e192fa404d7878..279a5c386ad15857e0a0f6ae18ccf7cc5183e0a6 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -67,7 +67,6 @@ class IrEmitterUnnested : public IrEmitter { Status HandleDot(HloInstruction* dot) override; Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleGather(HloInstruction* gather) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; @@ -77,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleRng(HloInstruction* random) override; Status HandleSelect(HloInstruction* select) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleGenerateToken(HloInstruction* gen_token) override; Status EmitTargetElementLoop( const HloInstruction& hlo, @@ -100,6 +100,13 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& inst, tensorflow::gtl::ArraySlice args); + // Helper for writing extra outputs from inside a reduce kernel. + Status EmitExtraOutputsForReduce( + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); + // EmitColumnReduction and EmitRowReduction emit code for column and row // reduction of a matrix and/or 3D tensor. Row and column reduction have // different memory access pattern, so for performance their implementations @@ -115,7 +122,11 @@ class IrEmitterUnnested : public IrEmitter { const Shape& input_shape, tensorflow::gtl::ArraySlice input_gens, tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers); + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Emits code that reduces a 3D tensor of shape [depth x height x width] to a // vector of shape [height]. Other parameters have the same meaning as those @@ -127,14 +138,22 @@ class IrEmitterUnnested : public IrEmitter { const Shape& input_shape, tensorflow::gtl::ArraySlice input_gens, tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers); + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Emits code that reduces a tensor of arbitrary rank to a scalar. Status EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, tensorflow::gtl::ArraySlice input_gens, tensorflow::gtl::ArraySlice init_value_gens, - tensorflow::gtl::ArraySlice reducers); + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Figures out whether `reduce` is a row or column reduction, and which // dimensions to reduce, and calls either `EmitRowReduction` or @@ -147,13 +166,21 @@ class IrEmitterUnnested : public IrEmitter { // Multiple reduces can be emitted in the same loop, assuming they have the // same input and output shapes, and the same reduce dimensions. // + // extra_output_gens can contain extra generators for intermediate outputs. + // These must have the same shape as the reduce input as they are computed + // when the reduce inputs are being read. + // // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, tensorflow::gtl::ArraySlice input_gens, tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - tensorflow::gtl::ArraySlice reducers); + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..9a4a1541ca3623b49b621ddde5236efb5dbbeaac --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -0,0 +1,270 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {} + +bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) { + auto get_element_instr = + [&](const HloInstruction* instr) -> const HloInstruction* { + const HloInstruction* element_instr = instr; + if (instr->opcode() == HloOpcode::kFusion) { + auto fused_expression_root = instr->fused_expression_root(); + if (instr->IsMultiOutputFusion()) { + // If possible, we want to pick a reduce operand of the fusion root, + // because it has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (inst->opcode() == HloOpcode::kReduce) { + return inst; + } + } + return fused_expression_root->operands()[0]; + } else { + element_instr = fused_expression_root; + } + } + return element_instr; + }; + + auto get_element_shape = [&](const HloInstruction* element_instr) { + // Special handling of kReduce instructions -- the fusion + // applies to the first operand. + if (element_instr->opcode() == HloOpcode::kReduce) { + return element_instr->operand(0)->shape(); + } + return element_instr->shape(); + }; + + // The shapes in all tuple operands should agree, unless it is a reduce. + // In that case, the operand of the reduce needs to have the same shape + // as the other tuple operands, but also we need to compare the output + // shapes of the reduces. + // TODO(tjoerg): Allow differences in fp precision. + auto* element_instr_1 = get_element_instr(instr1); + auto* element_instr_2 = get_element_instr(instr2); + if (element_instr_1->opcode() == HloOpcode::kReduce && + element_instr_2->opcode() == HloOpcode::kReduce && + !ShapeUtil::Equal(element_instr_1->shape(), element_instr_2->shape())) { + return false; + } + // The elementwise output shapes must be the same (including layout). + return ShapeUtil::Equal(get_element_shape(element_instr_1), + get_element_shape(element_instr_2)); +} + +namespace { +bool IsReduction(HloInstruction* instr) { + if (instr->IsMultiOutputFusion()) { + for (const HloInstruction* operand : + instr->fused_expression_root()->operands()) { + if (operand->opcode() == HloOpcode::kReduce) { + return true; + } + } + return false; + } else if (instr->opcode() == HloOpcode::kFusion) { + return instr->fused_expression_root()->opcode() == HloOpcode::kReduce; + } else { + return instr->opcode() == HloOpcode::kReduce; + } +} +} // namespace + +bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { + // We can fuse reduces and loop fusions. + return IsReduction(instr) || + (instr->opcode() == HloOpcode::kFusion && + instr->fusion_kind() == HloInstruction::FusionKind::kLoop && + // TODO(b/110202584): bitcasts make nested fusions, GPU has no support + // for nested fusions. + instr->fused_expression_root()->opcode() != HloOpcode::kBitcast); +} + +int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, + HloInstruction* instr2) { + tensorflow::gtl::FlatSet in_list; + for (auto instr : instr1->operands()) { + if (!IsProfitableOperand(instr)) { + continue; + } + in_list.insert(instr); + } + int64 profit = 0; + for (auto instr : instr2->operands()) { + if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) { + continue; + } + profit += ShapeUtil::ByteSizeOf(instr->shape()); + } + VLOG(2) << "Fusing instr1=" << instr1->name() << " instr2=" << instr2->name() + << ", the profit is =" << profit; + return profit; +} + +bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, + HloInstruction* instr2) { + if (!MultiOutputFusion::LegalToFuse(instr1, instr2)) { + return false; + } + // If we're fusing fusions only do it if the fusion kind matches. Loop fusions + // merge into bigger loop fusions and input (reduce) fusions become fusions + // with multiple reduce outputs. We could fuse reduce and loop fusions + // together too (the result being an input fusion) if we find cases where this + // improves things. + CHECK(instr1->opcode() == HloOpcode::kFusion); + if (instr2->opcode() == HloOpcode::kFusion) { + return instr1->fusion_kind() == instr2->fusion_kind(); + } + return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop; +} + +bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { + bool changed = false; + RecomputeReachability(); + + tensorflow::gtl::FlatSet to_fuse; + // Keep a list of the instructions to fuse after making all the fusion + // decisions. We first aggressively add instructions to potential_fusion_list, + // then filter out instructions that will be no longer fusable because of + // reachability change. This avoids recalculating reachability on a large set + // of instructions. + std::vector> + potential_fusion_list; + std::vector> fusion_list; + std::vector instrs_to_update_reachability; + + // For each reduce or reduce multi-output fusion, try to fuse it with loop + // fusions operands. + for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) { + if (consumer->user_count() == 0) { + continue; + } + if (!IsReduction(consumer)) { + continue; + } + // TODO(b/110517657): Lowering multi-output reduce fusions with bfloat16 + // output element types is not supported on GPU. However, bfloat16 is used + // in shared tests. + if (consumer->shape().element_type() == PrimitiveType::BF16) { + continue; + } + + auto consumer_operands = consumer->operands(); + for (size_t i = 0; i < consumer_operands.size(); ++i) { + HloInstruction* producer = consumer_operands[i]; + if (!producer->IsFusable()) { + continue; + } + const bool is_loop_fusion = + producer->opcode() == HloOpcode::kFusion && + producer->fusion_kind() == HloInstruction::FusionKind::kLoop; + if (!is_loop_fusion) { + continue; + } + if (!ShapesCompatibleForFusion(producer, consumer)) { + continue; + } + // If we have already decided to fuse this producer, skip it. + if (ContainsKey(to_fuse, producer)) { + continue; + } + // Do not fuse a producer if the other operands of the fusion are + // reachable from the producer, this would create a cycle. + if (std::any_of(consumer_operands.begin(), consumer_operands.end(), + [&](HloInstruction* operand) { + return producer != operand && + reachability()->IsReachable(producer, operand); + })) { + continue; + } + to_fuse.insert(producer); + potential_fusion_list.emplace_back(producer, consumer); + instrs_to_update_reachability.push_back(producer); + instrs_to_update_reachability.push_back(consumer); + break; + } + } + + // Filter out pairs that will be no longer fusable because of reachability + // change. + for (auto& fusion_pair : potential_fusion_list) { + HloInstruction* producer = fusion_pair.first; + HloInstruction* consumer = fusion_pair.second; + bool fusable = true; + for (size_t i = 0; i < consumer->operand_count(); ++i) { + if (producer != consumer->operand(i) && + reachability()->IsReachable(producer, consumer->operand(i))) { + fusable = false; + break; + } + } + if (fusable) { + UpdateReachability(producer, consumer, instrs_to_update_reachability); + fusion_list.push_back(fusion_pair); + } + } + + for (auto fusions_to_create : fusion_list) { + HloInstruction* producer = fusions_to_create.first; + HloInstruction* consumer = fusions_to_create.second; + if (consumer->opcode() != HloOpcode::kFusion) { + // Fusing with a reduce (fusion) always results in an input fusion. + HloInstruction* input_fusion = + computation()->AddInstruction(HloInstruction::CreateFusion( + consumer->shape(), HloInstruction::FusionKind::kInput, consumer)); + VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " + << consumer->name() << " into " << input_fusion->name(); + TF_CHECK_OK(computation()->ReplaceInstruction(consumer, input_fusion)); + if (producer->opcode() == HloOpcode::kFusion) { + input_fusion->MergeFusionInstructionIntoMultiOutput(producer); + } else { + input_fusion->FuseInstructionIntoMultiOutput(producer); + } + } else { + VLOG(2) << "Fuse producer " << producer->name() << " into its consumer " + << consumer->name(); + + if (producer->opcode() == HloOpcode::kFusion) { + consumer->MergeFusionInstructionIntoMultiOutput(producer); + } else { + consumer->FuseInstructionIntoMultiOutput(producer); + } + } + changed = true; + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..67ca5d49eee8508e93284b134f8410eb3a89f9ce --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +namespace xla { +namespace gpu { + +// Multi-output fusion of sibling and producer-consumer instructions for the +// Jellyfish backend. +class GpuMultiOutputFusion : public MultiOutputFusion { + public: + GpuMultiOutputFusion(); + + protected: + // Test if instr1 and instr2 have the compatible shapes that can be legally + // fused. + bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) override; + + // We currently only consider reduce and reduce fusion nodes as candidates. + bool IsFusible(HloInstruction* instr) override; + + // This function estimates the amount of memory reads saved by merging + // instr1 and instr2 into one multi-output fusion instruction. For a fusion + // instruction, all the operands need to be loaded from memory. If we merge + // instr1 and instr2, common operands will not be loaded twice. The profit is + // estimated as the size of the common operands b/w instr1 and instr2. + int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override; + + // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. + bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) override; + + // Fuse loop fusions into reduce fusions. + bool DoProducerConsumerMultiOutputFusion() override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bca277946470d65dbd12781b11432b566eb9ae90 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -0,0 +1,327 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace gpu { + +using InstructionFusionTest = HloTestBase; + +const char kModulePrefix[] = R"( + HloModule test_module + + scalar_add_computation { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0) + } + scalar_mul_computation { + scalar_lhs.1 = f32[] parameter(0) + scalar_rhs.1 = f32[] parameter(1) + ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1) + })"; + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { + // Fusion with reduce instruction root and a sibling reduce instruction + // sharing the same input param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] constant(1) + fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation + reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[6400]{0} parameter(1) + mul = f32[6400]{0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[6400]{0} parameter(1) + r1 = f32[64,100]{0,1} reshape(p1.2) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[6400]{0} parameter(1) + fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[10,10]{1,0} parameter(1) + mul = f32[10,10]{1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[10,10]{1,0} parameter(1) + const.2 = f32[10]{0} parameter(0) + ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1.3 = f32[10,10]{1,0} parameter(1) + fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1 + p2 = f32[] parameter(2) + fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) { + // Two sibling fusions with reduce instruction roots sharing the same input + // param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { + // Multi-output fusion with two reduce instructions root and a sibling reduce + // instruction sharing the same input param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { + const.1 = f32[] constant(1) + p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) + mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1) + reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2) + } + + ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + const = f32[] constant(1) + fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation + get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0 + get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1 + reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { + // Verify that if we already have a multi-output fusion that we prefer to pick + // a reduce op from its operands for checking shape compatibility. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[10,10]{1,0} parameter(1) + mul = f32[10,10]{1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation + ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1) + } + + fused_computation_2 { + p1.2 = f32[10,10]{1,0} parameter(1) + const.2 = f32[10] parameter(0) + ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[10,10]{1,0} parameter(1) + p2 = f32[10]{0} parameter(2) + fusion.1 = (f32[10,10], f32[10]) fusion(p0, p1), kind=kInput, calls=fused_computation_1 + get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=0 + get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=1 + fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[6400]{0} parameter(0) + const.2 = f32[] constant(1) + ROOT div = f32[6400]{0} divide(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Divide())); +} + +TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_add { + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add + reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation + ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Add())); +} + +TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_select { + p1.1 = f32[2,2,2]{2,1,0} parameter(1) + c0 = f32[] constant(0) + broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={} + greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast) + p0.1 = f32[2,2,2]{2,1,0} parameter(0) + ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast) + } + + fused_reduce { + p0.2 = f32[2,2,2]{2,1,0} parameter(0) + c1 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation + mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + p1 = f32[2,2,2]{2,1,0} parameter(1) + select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select + fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce + gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0 + gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1 + ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(), + op::GetTupleElement())); + const HloInstruction* fusion = root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Select())); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index d8c07dc3119fb81a3ef22822acb11b7c4d5bbca5..cd833ec7bd858aabee84ac306d198e80eb112506 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -58,7 +58,7 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { @@ -71,14 +71,13 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( // // %nctaid.x is currently specified as 2147483647. VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor " << unroll_factor_; + CHECK_NE(index_type, nullptr); std::vector array_indices; - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), static_cast(block_id)); - block_id = - ir_builder_->CreateZExt(block_id, ir_builder_->getInt64Ty(), "block_id"); + block_id = ir_builder_->CreateZExtOrTrunc(block_id, index_type, "block_id"); // Per the PTX documentation: // "It is guaranteed that [...] 0 <= %tid.x < %ntid.x" @@ -88,13 +87,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, ir_builder_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), static_cast(thread_id)); - thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(), - "thread_id"); + thread_id = + ir_builder_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); llvm::Value* linear_index_base = ir_builder_->CreateAdd( ir_builder_->CreateMul( block_id, - ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "", + llvm::ConstantInt::get(index_type, + launch_dimensions_.threads_per_block()), + "", /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); @@ -110,21 +111,23 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm::Intrinsic::assume, {ir_builder_->CreateICmpULT( linear_index_base, - ir_builder_->getInt64(launch_dimensions_.threads_per_block() * - launch_dimensions_.block_count()), + llvm::ConstantInt::get(index_type, + launch_dimensions_.threads_per_block() * + launch_dimensions_.block_count()), "linear_index_in_range")}, {}, ir_builder_); if (unroll_factor_ > 1) { linear_index_base = ir_builder_->CreateMul( - linear_index_base, ir_builder_->getInt64(unroll_factor_), + linear_index_base, llvm::ConstantInt::get(index_type, unroll_factor_), "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); } array_indices.emplace_back(linear_index_base, shape_, ir_builder_); for (int i = 1; i < unroll_factor_; ++i) { llvm::Value* linear_index = ir_builder_->CreateAdd( - linear_index_base, ir_builder_->getInt64(i), "linear_index", + linear_index_base, llvm::ConstantInt::get(index_type, i), + "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); array_indices.emplace_back(linear_index, shape_, ir_builder_); } @@ -132,7 +135,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( auto if_in_bounds = llvm_ir::EmitIfThenElse( ir_builder_->CreateICmpULT( linear_index_base, - ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), + llvm::ConstantInt::get(index_type, ShapeUtil::ElementsIn(shape_))), llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false); // Set exit_bb_ to the exit block of the if structure. diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 25318b3bed8bf4a2dfe3a4a974269d0405c3bfec..302e1bf1bc8e90f2eebd838f156a1552e86185ac 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -58,7 +58,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) override; + tensorflow::StringPiece loop_name, llvm::Type* index_type) override; private: // The thread and block dimension to parallelize the loop on. diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index c125474edb1036090a926020f2b1e7fcf64c751a..02471129e004b4876ce20a62cade34060c65b478 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -47,6 +47,7 @@ class LaunchDimensions { int64 block_count() const { return block_count_; } int64 threads_per_block() const { return threads_per_block_; } + int64 launch_bound() const { return block_count() * threads_per_block(); } private: int64 block_count_; diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 696fa7e0194032b5c78bf11383c3280a62de07fa..6f4bb0580e8dfc1dce1cca0a60cc3dd9ea600fb3 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -33,8 +33,7 @@ class StreamAssignmentTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", VersionedComputationHandle(), - config); + return MakeUnique("test_module", config); } // Pre-canned shapes. diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 06a5e0351b63270b61b998ca2211f480f256f759..a04aa4069d2344ca7b2e763cfeeb53abcbefc21d 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -26,6 +26,46 @@ namespace xla { using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; +/*static*/ +StatusOr HeapSimulator::MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function) { + if (module_sequence.empty()) { + return 0; + } + + const HloModule* module = module_sequence.begin()->first->parent(); + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; +} + +/*static*/ +StatusOr HeapSimulator::MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap* + memory_by_computation) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function, + HeapSimulator::Options(), memory_by_computation)); + return result.heap_size; +} + /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, @@ -46,9 +86,11 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options, + const tensorflow::gtl::FlatMap* + memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr); + /*module_sequence=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -219,6 +261,12 @@ Status HeapSimulator::RunComputation( Alloc(buffer, instruction); } } + // Account for the memory used by subcomputations when estimating the + // current heap size. + if (memory_by_computation_ != nullptr) { + algorithm_->AccountForSubcomputationMemory(instruction, + *memory_by_computation_); + } // If the whole module is sequential, we can save memory by running the // heap-simulation for sub-computations inline. E.g. the buffers for the @@ -286,12 +334,15 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence) + const SequentialHloOrdering::HloModuleSequence* module_sequence, + const tensorflow::gtl::FlatMap* + memory_by_computation) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence) { + module_sequence_(module_sequence), + memory_by_computation_(memory_by_computation) { debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); } @@ -460,6 +511,26 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } +void NoFragmentationStatsHeap::AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + // We only count the memory usage of the largest subcomputation, instead of + // adding them all, because subcomputations won't execute in parallel. + int64 max_subcomputation_bytes = 0; + for (const auto* c : instruction->called_computations()) { + auto it = memory_by_computation.find(c); + if (it != memory_by_computation.end()) { + int64 subcomputation_bytes = it->second; + if (subcomputation_bytes > max_subcomputation_bytes) { + max_subcomputation_bytes = subcomputation_bytes; + } + } + } + max_heap_size_ = + std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes); +} + void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) { current_heap_size_ -= size; } diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 8b2b43a37a5c41d334e5338c6a6fad160f03a51e..811a6042df9434ac3f4bed71b9c093433e25c1bb 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -85,6 +85,23 @@ class HeapSimulator { const BufferValueFlatSet* buffers_to_assign; }; + // Returns the minimum memory required to compute an HLO module where all + // computations have been scheduled (represented by the given + // module_sequence), assuming no fragmentation. + static StatusOr MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function); + + // Returns the minimum memory required to compute the given computation, + // assuming no fragmentation. + static StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); + // Run the heap simulation with the given algorithm, assuming the given // module_sequence, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid @@ -111,7 +128,9 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + const Options& options = Options(), + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); private: // If 'module_sequence' is non-null, it is used to find kCall and kWhile @@ -120,7 +139,9 @@ class HeapSimulator { HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence); + const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); ~HeapSimulator(); Status RunComputation( @@ -144,7 +165,13 @@ class HeapSimulator { const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; + // module_sequence_ is set by buffer assignment, and memory_by_computation_ is + // set by hlo scheduling. Then, in RunComputation, we check both in order to + // handle subcomputations. It would be good to unify the handling of + // subcomputations, but it's not clear how. const SequentialHloOrdering::HloModuleSequence* module_sequence_; + const tensorflow::gtl::FlatMap* + memory_by_computation_; // In addition to Alloc and Free, the heap simulator exposes a concept of // buffer sharing. When ShareBuffer is called, instead of allocating new @@ -189,6 +216,11 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; + virtual void AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) {} + // Free de-allocates a previously allocated buffer. virtual void Free(const BufferValue* buffer, int64 size) = 0; @@ -207,7 +239,14 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { ~NoFragmentationStatsHeap() override = default; void Alloc(const BufferValue* buffer, int64 size) override; + + void AccountForSubcomputationMemory( + const HloInstruction* instruction, + const tensorflow::gtl::FlatMap& + memory_by_computation) override; + void Free(const BufferValue* buffer, int64 size) override; + Result Finish() override; private: diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 6271652412c2979ff926702f12722102344b0dfb..93d7a141258a3186b10cf2728b70a034488a84f2 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -34,6 +34,65 @@ limitations under the License. namespace xla { namespace { +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) + .ValueOrDie()); +} + const char kAlloc[] = "Alloc"; const char kFree[] = "Free"; const char kFinish[] = "Finish"; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 1f7c1cffd324ad2f4e4cdf11046c8459b8ceb6d5..d2417910606fdd13223076d33ff1bda1dd291d98 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -150,6 +150,11 @@ message HloInstructionProto { // Backend configuration for the instruction. Has backend-specific meaning. string backend_config = 43; + + // Cross Replica Sum fields. + repeated int64 replica_group_ids = 44; + int64 all_reduce_id = 45; + string cross_replica_sum_barrier = 46; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index a88283ed9a6459b4fa9310e160b59c77d51f1027..0a948cc390fed7daed3e0cc938bf59cbcfd9b4df 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -493,6 +493,16 @@ StatusOr> HloAliasAnalysis::Run( bool HloAliasAnalysis::HasLiveRangeInterference( const HloOrdering& ordering) const { for (const HloBuffer& buffer : buffers()) { + CHECK(!buffer.values().empty()); + if (ShapeUtil::IsToken(buffer.values().front()->shape())) { + // Tokens have no on-device representation and cannot interfere. + for (const HloValue* value : buffer.values()) { + // If one of the values is a token, all values must be a token. + DCHECK(ShapeUtil::IsToken(value->shape())); + } + continue; + } + // Check that the values in the buffer are totally ordered with respect to // 'ordering'. Begin by sorting the values with respect to 'ordering' with a // tie-break using value ID. The tie-break is necessary because we need a @@ -517,7 +527,6 @@ bool HloAliasAnalysis::HasLiveRangeInterference( // a buffer and A interferes with C, then necessarily A also interferes // with B. So to check interference you only need to check interference // between A and B, and between B and C. - CHECK(!values.empty()); for (int i = 1; i < values.size(); ++i) { if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) { VLOG(1) << values[i - 1]->ToShortString() << " and " diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h index b15f1f24c607715f5483df492748fe1ca1dccefa..7f73bba036534a62a70a80431236cffa766c9b38 100644 --- a/tensorflow/compiler/xla/service/hlo_casting_utils.h +++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h @@ -18,10 +18,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ -#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include +#include "tensorflow/core/platform/logging.h" namespace xla { +class HloInstruction; + template using EnableIfDerivedFromHlo = typename std::enable_if::value>::type; diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc index 436a9222342dd853c8adb545632c5c39f577ba97..a3364275409122254bf99b40a7d2fcbb2d7564cc 100644 --- a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index b61eabbbf526249710ee434565bb68a493a089d5..c057be82014b00e3ff63f835fcb78c08f8d9c154 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -64,7 +64,7 @@ HloComputation::HloComputation( const string& name, int parameter_count, std::vector>* instructions, HloInstruction* root_instruction, HloInstruction* fusion_instruction) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), unique_id_(-1), root_instruction_(root_instruction), fusion_instruction_(fusion_instruction) { @@ -234,7 +234,6 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { TF_RET_CHECK(instruction_iterators_.count(instruction) != 0); auto inst_it = instruction_iterators_.at(instruction); (*inst_it)->set_parent(nullptr); - instruction->DetachFromOperands(); instructions_.erase(inst_it); return Status::OK(); } @@ -264,46 +263,11 @@ void HloComputation::set_root_instruction( namespace { -// Helper class which computes the post order of an expression rooted at a -// particular instruction. -class InstructionPostOrderer : public DfsHloVisitorWithDefault { - public: - // added_instructions is the set of instructions which have already been - // accounted for in the post order in previous invocations of - // GetOrder. Without this mechanism, instructions which are predecessors of - // multiple root instructions of the computation can be added to the post - // order more than once. - static std::list GetOrder( - HloInstruction* root, - tensorflow::gtl::FlatSet* added_instructions) { - InstructionPostOrderer orderer(added_instructions); - TF_CHECK_OK(root->Accept(&orderer)); - return std::move(orderer.post_order_); - } - - private: - explicit InstructionPostOrderer( - tensorflow::gtl::FlatSet* added_instructions) - : added_instructions_(added_instructions) {} - ~InstructionPostOrderer() override {} - - Status DefaultAction(HloInstruction* hlo_instruction) override { - if (added_instructions_->count(hlo_instruction) == 0) { - post_order_.push_back(hlo_instruction); - added_instructions_->insert(hlo_instruction); - } - return Status::OK(); - } - - std::list post_order_; - tensorflow::gtl::FlatSet* added_instructions_; -}; - // Helper which builds a post order of the HLO call graph. void ComputeComputationPostOrder( HloComputation* computation, tensorflow::gtl::FlatSet* visited, - std::list* post_order) { + std::vector* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -315,12 +279,53 @@ void ComputeComputationPostOrder( } } +enum State { kVisiting, kVisited }; + +void ComputeInstructionPostOrder( + std::vector* post_order, HloInstruction* root, + tensorflow::gtl::FlatMap* visited) { + std::vector dfs_stack; + dfs_stack.push_back(root); + while (!dfs_stack.empty()) { + const auto current = dfs_stack.back(); + auto it = visited->find(current); + if (it != visited->end()) { + if (it->second == kVisited) { + // Already visited. + dfs_stack.pop_back(); + continue; + } + // Visit this node. + CHECK_EQ(kVisiting, it->second); + dfs_stack.pop_back(); + post_order->push_back(current); + it->second = kVisited; + continue; + } + + visited->insert({current, kVisiting}); + + // Add the operands to the stack in reverse order so the first operand is + // processed first. This will produce a more natural ordering and a nicer + // result for thigns like HLO stringification. + const auto& operands = current->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + dfs_stack.emplace_back(operands[i]); + } + + for (HloInstruction* op : current->control_predecessors()) { + dfs_stack.emplace_back(op); + } + } +} + } // namespace -std::list HloComputation::MakeInstructionPostOrder() const { - std::list post_order; - std::list trace_instructions; - tensorflow::gtl::FlatSet added_instructions; +std::vector HloComputation::MakeInstructionPostOrder() const { + std::vector post_order; + post_order.reserve(instruction_count()); + std::vector trace_instructions; + tensorflow::gtl::FlatMap visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -328,21 +333,20 @@ std::list HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - post_order.splice(post_order.end(), - InstructionPostOrderer::GetOrder(instruction.get(), - &added_instructions)); + ComputeInstructionPostOrder(&post_order, instruction.get(), &visited); } } - post_order.splice(post_order.end(), trace_instructions); + post_order.insert(post_order.end(), trace_instructions.begin(), + trace_instructions.end()); CHECK_EQ(instructions_.size(), post_order.size()) << "number of instructions does not match post order size"; return post_order; } -std::list HloComputation::MakeEmbeddedComputationsList() +std::vector HloComputation::MakeEmbeddedComputationsList() const { tensorflow::gtl::FlatSet visited; - std::list post_order; + std::vector post_order; // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after @@ -488,21 +492,7 @@ HloInstruction* HloComputation::CreateFusionInstruction( StatusOr HloComputation::DeepCopyHelper( HloInstruction* instruction, const ShapeTree* indices_to_copy, ShapeTree* copies_added, ShapeIndex* index) { - if (ShapeUtil::IsArray(instruction->shape())) { - if (indices_to_copy == nullptr || indices_to_copy->element(*index)) { - // Use kCopy to copy array elements - HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kCopy, instruction)); - if (copies_added != nullptr) { - *copies_added->mutable_element(*index) = copy; - } - return copy; - } else { - // Array elements which are not to be copied are passed through - // transparently. - return instruction; - } - } else if (ShapeUtil::IsTuple(instruction->shape())) { + if (ShapeUtil::IsTuple(instruction->shape())) { std::vector elements; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); i++) { @@ -519,9 +509,27 @@ StatusOr HloComputation::DeepCopyHelper( index->pop_back(); } return AddInstruction(HloInstruction::CreateTuple(elements)); + } + if (ShapeUtil::IsToken(instruction->shape())) { + // Tokens have no on-device representation and cannot be copied. Pass + // through transparently. + return instruction; + } + + // Array shape. + TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape())); + if (indices_to_copy == nullptr || indices_to_copy->element(*index)) { + // Use kCopy to copy array elements + HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kCopy, instruction)); + if (copies_added != nullptr) { + *copies_added->mutable_element(*index) = copy; + } + return copy; } else { - return FailedPrecondition( - "Can only copy array and tuple shaped instructions"); + // Elements which are not to be copied are passed through + // transparently. + return instruction; } } @@ -609,7 +617,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr HloComputation::ComputeReachability() const { - const std::list all = MakeInstructionPostOrder(); + const auto& all = MakeInstructionPostOrder(); auto result = MakeUnique(all); std::vector inputs; @@ -827,15 +835,6 @@ std::unique_ptr HloComputation::CloneWithReplacements( } } context->MapComputation(this, result.get()); - // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before - // they're destroyed, otherwise they stick around in the operands' users lists - // and cause use-after-frees. - for (auto& kv : replacements) { - if (std::unique_ptr& new_instr = kv.second) { - new_instr->DetachFromOperands(); - } - } return result; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 0da4a305f3d5d694a1918fed294337100b0a27fd..0f111a1a7672d419d32387d7fe0020744ba8ddf2 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -199,7 +199,7 @@ class HloComputation { // Compute and return a post-order of the instructions in the computation. In // this order, definitions of values always appear before their uses. - std::list MakeInstructionPostOrder() const; + std::vector MakeInstructionPostOrder() const; // Computes and returns the reachability between HLO instructions in the // computation. The returned HloReachabilityMap is constructed such that @@ -221,7 +221,7 @@ class HloComputation { // transitively. The embedded computations are sorted such that if computation // A calls computation B (eg, via a map instruction) then A will appear after // B in the list. - std::list MakeEmbeddedComputationsList() const; + std::vector MakeEmbeddedComputationsList() const; // Creates a fusion instruction containing the given instructions. // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 25469a54c48f4f5cab478aba929f1cc18de8b81f..c504fc51d229ca70499bfe006ed9c350251d2c8a 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -371,6 +371,38 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { } } +TEST_F(HloComputationTest, DeepCopyToken) { + // Test that DeepCopyInstruction properly handles tokens which should not be + // copied. + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + auto copy = computation->DeepCopyInstruction(token).ValueOrDie(); + + // No copy should be added. + EXPECT_THAT(copy, op::GenerateToken()); +} + +TEST_F(HloComputationTest, DeepCopyTokenTuple) { + // Test that DeepCopyInstruction properly handles tokens which should not be + // copied. + auto builder = HloComputation::Builder(TestName()); + auto token = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({token, constant})); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); + + // Only the array (second tuple element) should be copied. The token is passed + // through transparently. + EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple), + op::Copy(op::GetTupleElement(tuple)))); +} + TEST_F(HloComputationTest, CycleDetection) { // Test whether the visitor can detect cycles in the graph. auto builder = HloComputation::Builder(TestName()); @@ -385,6 +417,9 @@ TEST_F(HloComputationTest, CycleDetection) { // Add a control dependency to create a cycle. ASSERT_IS_OK(add->AddControlDependencyTo(negate)); + auto instructions = computation->MakeInstructionPostOrder(); + EXPECT_EQ(3, instructions.size()); + const auto visitor = [](HloInstruction* instruction) { return Status::OK(); }; auto visit_status = computation->Accept(visitor); ASSERT_FALSE(visit_status.ok()); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 94c9c7eabcc99d4cf61f535925c068a9b55ed136..762e1afc71b108b2e32b5a7f7f1bbeb783fc6fbd 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -172,15 +172,22 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { + current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; return Status::OK(); } -Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleDynamicSlice( + const HloInstruction* dynamic_slice) { + current_properties_[kBytesAccessedKey] = + shape_size_(dynamic_slice->shape()) * 2; return Status::OK(); } -Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleDynamicUpdateSlice( + const HloInstruction* dynamic_update_slice) { + current_properties_[kBytesAccessedKey] = + shape_size_(dynamic_update_slice->operand(1)->shape()) * 2; return Status::OK(); } @@ -386,6 +393,10 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { auto lhs = convolution->operand(0); auto rhs = convolution->operand(1); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index d17678d20f2a23fd98d18b77d5fb25853901a789..0d66736fe1d0677d13a63ede7a203d6ac20c76f5 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleGenerateToken(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 16fdda8a8b9ade09ea31cda1f4cf5e8ff2c0a081..d22bef56730da194816b4ee89dc3196439b350f9 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -460,5 +460,51 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { EXPECT_EQ(analysis.flop_count(), 1472); } +TEST_F(HloCostAnalysisTest, Slice) { + // Test the analysis on a slice. + XlaBuilder builder("slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.Slice(x, {0}, {1}, {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + +TEST_F(HloCostAnalysisTest, DynamicSlice) { + // Test the analysis on a slice. + XlaBuilder builder("dynamic-slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.DynamicSlice(x, builder.ConstantR1({1}), {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + +TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { + // Test the analysis on a slice. + XlaBuilder builder("dynamic-update-slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.DynamicUpdateSlice(x, builder.ConstantR1({1.0}), + builder.ConstantR1({1})); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index dab946a099fa0066a4a0d42ce29077b9de6a486e..a0ee8896230d6dcacb5a8eb607fc00ae5226cfa5 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -135,17 +135,18 @@ StatusOr HloCSE::Run(HloModule* module) { // instruction for each class. tensorflow::gtl::FlatSet - representatives(/*N=*/1024, &CseHash, cse_equal); - + representatives(/*N=*/computation->instruction_count() + 1, &CseHash, + cse_equal); for (auto instruction : computation->MakeInstructionPostOrder()) { // If the instruction has zero operands (constants, parameters, etc.) skip // over it. if (instruction->operand_count() == 0) { continue; } - - // Skip instructions which have side effects. - if (instruction->HasSideEffect()) { + // Skip instructions which have side effects or are a domain (which must + // not be CSE-ed). + if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kDomain) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index cc130a4900dc162d4b416116fbe879fec37136a2..d0200058683b2db8f5f0469d6c643014881f741e 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -931,16 +931,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } const HloUse& use = value.uses()[0]; - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -967,6 +968,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( use.operand_number == other_add_operand_index; } } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, @@ -998,8 +1000,13 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( }) != uses.end(); return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; } - // Check if 'user' is element-wise. - return user->IsElementwise(); + + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation. + return user->IsElementwiseOnOperand(user->operand_index(operand)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 5798326dcbf65c3c34748afb02afab1dc7af9147..db1822ec47a7f52e2c3ef8dcbf433cd787ef75ab 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1974,6 +1974,89 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); } +TEST_F(CanShareOperandBufferWithUserTest, + NonElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "param0")); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0)); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, neg, {0, 1})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {reverse, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + MultiOutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + + auto copy0 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0)); + auto copy1 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {1})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + ElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {exp, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { auto builder = HloComputation::Builder(TestName()); @@ -2048,6 +2131,46 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { fusion, {})); } +TEST_F(CanShareOperandBufferWithUserTest, + FusedDynamicUpdateSliceWithConvertCantShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + auto convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape_bf16, gte1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape_bf16, convert1, update, starts)); + + auto convert2 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape, dynamic_update_slice)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {convert2, dynamic_update_slice, starts, update, convert1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can't share with tuple element 1. + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index fcd723af146e2227b8661b1a4993f1338f7de389..8aa26bf520bbbc54a86acbc8e0c6be2e142df0dc 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -85,8 +85,7 @@ StatusOr HloDCE::Run(HloModule* module) { } // Remove dead computations. - std::list computations = module->MakeComputationPostOrder(); - for (auto* computation : computations) { + for (auto* computation : module->MakeComputationPostOrder()) { if (live_computations.count(computation) == 0) { TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); changed = true; diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index e0c5718509dabebb7b9307bf764b0ea1ce7369a0..eded3e78eead76c4564daee119034c5031eba409 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -26,10 +26,10 @@ limitations under the License. namespace xla { // Domain isolation is the task of placing kDomain instructions between HLO -// instructions having different shrading. A kDomain instruction is essentially +// instructions having different sharding. A kDomain instruction is essentially // used to break an HLO graph edge connecting two instructions with different // sharding. If a set of connected instructions have all the same sharding, no -// kDomain instruciton will be placed. +// kDomain instruction will be placed. class HloDomainIsolator : public HloPassInterface { public: // Creates a new kDomain instruction for the edge between the use instruction diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 5553ddb153f7f1f2e6a790890c11f35e192488c4..5d8081c1ef197548e1d802374f3efe35fa113cd3 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -21,12 +21,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { -class HloDomainTest : public HloTestBase { +class HloDomainTest : public HloVerifiedTestBase { protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -64,11 +65,11 @@ class HloDomainTest : public HloTestBase { return false; } - StatusOr> ParseModule( - tensorflow::StringPiece hlo_string) { + StatusOr ParseModule(tensorflow::StringPiece hlo_string) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return ParseHloString(hlo_string, config); + ParseAndVerifyModule(hlo_string, config); + return &module(); } }; @@ -143,32 +144,31 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); } TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { @@ -186,12 +186,11 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(!isolator_changed); } @@ -212,27 +211,26 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e")); - EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_TRUE(HasDomainEdge(module, "b", "a")); + EXPECT_TRUE(HasDomainEdge(module, "f", "e")); + EXPECT_FALSE(HasDomainEdge(module, "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover remover(ShardingMetadata::KindName(), NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); - EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e")); + EXPECT_FALSE(HasDomainEdge(module, "b", "a")); + EXPECT_FALSE(HasDomainEdge(module, "f", "e")); } TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { @@ -248,12 +246,11 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_FALSE(isolator_changed); } @@ -270,16 +267,15 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainRemover remover(ShardingMetadata::KindName(), NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_FALSE(remover_changed); - HloInstruction* add = FindInstruction(module.get(), "c"); + HloInstruction* add = FindInstruction(module, "c"); ASSERT_NE(add, nullptr); auto device = add->sharding_unique_device(); EXPECT_TRUE(device.has_value()); @@ -302,42 +298,41 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator sharding_isolator(CreateShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, - sharding_isolator.Run(module.get())); + sharding_isolator.Run(module)); EXPECT_TRUE(sharding_isolator_changed); HloDomainIsolator opname_isolator(OpNameDomainCreator); TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, - opname_isolator.Run(module.get())); + opname_isolator.Run(module)); EXPECT_TRUE(opname_isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); - EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); HloDomainRemover sharding_remover(ShardingMetadata::KindName(), NormalizeShardingDomain); TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, - sharding_remover.Run(module.get())); + sharding_remover.Run(module)); EXPECT_TRUE(sharding_remover_changed); HloDomainRemover opname_remover(OpNameMetadata::KindName(), OpNameDomainNormalizer); TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, - opname_remover.Run(module.get())); + opname_remover.Run(module)); EXPECT_TRUE(opname_remover_changed); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); - EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "c")); } TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { @@ -355,18 +350,17 @@ ENTRY entry { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); HloDomainIsolator isolator(CreateShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); - EXPECT_TRUE(HasDomainEdge(module.get(), "gte0", "infeed")); - EXPECT_TRUE(HasDomainEdge(module.get(), "gte1", "infeed")); - EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); - EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); + EXPECT_TRUE(HasDomainEdge(module, "gte0", "infeed")); + EXPECT_TRUE(HasDomainEdge(module, "gte1", "infeed")); + EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1")); // Inject unassigned tuple/gte within the infeed domain, to simulate the // HLO passes adding unexpected instructions. @@ -381,7 +375,7 @@ ENTRY entry { // TUPLE // | // DOMAIN - HloInstruction* infeed = FindInstruction(module.get(), "infeed"); + HloInstruction* infeed = FindInstruction(module, "infeed"); ASSERT_NE(infeed, nullptr); auto infeed_users = infeed->users(); HloInstruction* new_gte0 = @@ -404,7 +398,7 @@ ENTRY entry { HloDomainRemover remover(ShardingMetadata::KindName(), NormalizeShardingDomain); - TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); EXPECT_TRUE(remover_changed); struct Assignment { diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index abec29df433c521c3480b9297000085b1b1104e3..4ed1508d7067684a15d0fb7d86e69b055bc1333b 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -141,6 +141,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops with embedded computations where it suffices to convert // the embedded computations instead of converting the ops themselves. if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || + opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kSelectAndScatter || diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 1e78d775c8e172a272a03fbd1101cef365e6dc2d..33424019b93feff862c6e3e268ae3980bacc9142 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -300,12 +300,6 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( instruction->CloneWithNewOperands(instruction->shape(), operands); auto result = Evaluate(cloned_instruction.get()); - // Clean up our cloned instructions before returning. - cloned_instruction->DetachFromOperands(); - for (auto& operand : owned_operands) { - operand->DetachFromOperands(); - } - return result; } @@ -321,7 +315,6 @@ StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( rhs_instr.get()); auto result = Evaluate(cloned_instruction.get()); - cloned_instruction->DetachFromOperands(); return result; } @@ -334,7 +327,6 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); auto result = Evaluate(cloned_instruction.get()); - cloned_instruction->DetachFromOperands(); return result; } @@ -372,7 +364,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { // The result concatenate dimension is going to be the sum of all // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); - CHECK(!ShapeUtil::IsTuple(reference_shape)); + CHECK(ShapeUtil::IsArray(reference_shape)); const int64 rank = ShapeUtil::Rank(reference_shape); const int64 concat_dim = concatenate->dimensions()[0]; CHECK_GE(concat_dim, 0); @@ -383,7 +375,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (int64 i = 1; i < operands.size(); ++i) { const Shape& operand_shape = operands[i]->shape(); - CHECK(!ShapeUtil::IsTuple(operand_shape)); + CHECK(ShapeUtil::IsArray(operand_shape)); // Accumulate the concat dimension from all tensors taking part to the // operation. concat_dimensions[concat_dim] += @@ -910,6 +902,11 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } +Status HloEvaluator::HandleGenerateToken(HloInstruction* token) { + evaluated_[token] = Literal::CreateToken(); + return Status::OK(); +} + Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index b53d5644de5a17c52bdbf2593ce52f0227008a00..fc2fc9437b238a2e519401b2b121dfbef070e2dc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -174,6 +174,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleGenerateToken(HloInstruction* token) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 84b4ead2dd28caa40b6d7830a1e1401be88b6b36..72eb9930e92c340ab9f42cd563c27507623b2ba7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -1248,7 +1248,7 @@ void BM_ReducePrecisely(int num_iters) { HloComputation::Builder b("BM_ReducePrecisely"); HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config); + HloModule module("BM_ReducePrecisely", config); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 13f46407e33e36bdbef4c9032630101d6c18268f..7e97eacf354ead688a57602ac39e9963250d197a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -778,7 +778,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override { CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(!ShapeUtil::IsTuple(select->shape())); + CHECK(ShapeUtil::IsArray(select->shape())); std::function select_op = [](bool pred, ReturnT on_true, ReturnT on_false) { if (pred) { @@ -1103,7 +1103,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } Status HandlePad(HloInstruction* pad) override { - CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); + CHECK(ShapeUtil::IsArray(pad->operand(0)->shape())); // Padding value must be scalar. CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), @@ -1116,7 +1116,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { /*padding_config=*/pad->padding_config())); CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); // Create new HLO of padded shape with padding value. @@ -1182,7 +1182,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { dynamic_slice->dynamic_slice_sizes())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); TF_RET_CHECK( primitive_util::IsIntegralType(start_indices->shape().element_type())); @@ -1237,7 +1237,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { operand->shape(), update->shape(), start_indices->shape())); TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); TF_RET_CHECK( primitive_util::IsIntegralType(start_indices->shape().element_type())); @@ -1378,6 +1378,44 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleSort(HloInstruction* sort) { + TF_RET_CHECK(ShapeUtil::Rank(sort->shape()) == 1) + << "Sort is only supported for R1 shapes"; + + auto arg = sort->operand(0); + const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); + VLOG(3) << "HandleSort arg_literal: " << arg_literal.ToString(); + const auto& arg_data = arg_literal.data(); + + std::vector return_data(arg_data.begin(), arg_data.end()); + std::sort(return_data.begin(), return_data.end(), + [](const ReturnT& a, const ReturnT& b) { + return SafeLess(a, b); + }); + auto result_literal = MakeUnique(sort->shape()); + result_literal->PopulateR1( + tensorflow::gtl::ArraySlice(return_data)); + VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); + parent_->evaluated_[sort] = std::move(result_literal); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleSort(HloInstruction* sort) { + return InvalidArgument("Unsupported type for Sort"); + } + + Status HandleSort(HloInstruction* sort) override { + return HandleSort(sort); + } + Status HandleReduce(HloInstruction* reduce) override { auto arg = reduce->operand(0); auto init_value = reduce->operand(1); @@ -1393,7 +1431,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { /*to_apply=*/function->ComputeProgramShape())); TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); @@ -1613,7 +1651,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) << "return shape is set to: " << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) - << "but is inferred to be: " + << " but is inferred to be: " << ShapeUtil::HumanStringWithLayout(inferred_return_shape); const Literal& operand_literal = @@ -2118,6 +2156,38 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return rhs_unsigned >= lhs_size_unsigned; } + // It's UB to use std::sort with std::less, because of NaNs. Define + // "safe" less functions which are actually strict weak orders. + template ::value>::type* = + nullptr> + static bool SafeLess(const NativeT& a, const NativeT& b) { + return a < b; + } + + template ::value || + std::is_same::value>::type* = nullptr> + static bool SafeLess(const NativeT& a, const NativeT& b) { + if (std::isnan(b)) { + return !std::isnan(a); + } else { + return a < b; + } + } + + template ::value>::type* = nullptr> + static bool SafeLess(const NativeT& a, const NativeT& b) { + if (Eigen::half_impl::isnan(b)) { + return !Eigen::half_impl::isnan(a); + } else { + return a < b; + } + } + HloEvaluator* parent_; }; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 61612bebd1e906d2d055e2f70de29da53275d4e8..ab224021c54fb3f5c5b69d0b633a080c304d5edd 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -28,6 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -723,11 +725,25 @@ string HloDotDumper::DumpRootTag() { to_id, node_body, node_shape, NodeColorAttributes(color)); } +static const HloConstantInstruction* TryGetFusionParameterConstant( + const HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) { + return nullptr; + } + const HloInstruction* fusion = instr->parent()->FusionInstruction(); + const HloInstruction* operand = fusion->operand(instr->parameter_number()); + return DynCast(operand); +} + bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // If a node: // - // - is a tuple-shaped parameter, - // - is not a parameter to a fusion node, + // - is a parameter of a fusion node which is bound to a constant, + // + // or + // + // - is a tuple-shaped parameter, and + // - is not a parameter to a fusion node, and // - has at least kMinUsersToOmit users shown, and // - all of the shown users are get-tuple-elements, // @@ -735,6 +751,9 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // // This helps us handle the common case where a while loop body has one big // tuple-shaped parameter. + if (TryGetFusionParameterConstant(instr) != nullptr) { + return true; + } const int kMinUsersToOmit = 3; return instr->opcode() == HloOpcode::kParameter && ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && @@ -806,26 +825,26 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { - auto stringify_constant = [](const HloInstruction* constant) { + auto stringify_constant = [](const HloConstantInstruction* constant) { const auto& shape = constant->shape(); // If the shape has a dimension of size zero, print it as e.g. // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. - if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) { + if (ShapeUtil::IsZeroElementArray(shape)) { return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); } // Print the literal value of constants with <= K elements. optional elem_count; - if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { + if (ShapeUtil::IsArray(shape)) { elem_count = 1; for (int64 dim : shape.dimensions()) { *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { + if (elem_count.has_value() && *elem_count <= 8) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -841,29 +860,26 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( ShapeUtil::HumanString(constant->shape())); }; - // Special case: If instr is a parameter to a fusion node, check whether the - // corresponding operand to the fusion node is a constant. - if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { - const HloInstruction* fusion = instr->parent()->FusionInstruction(); - const HloInstruction* operand = fusion->operand(instr->parameter_number()); - if (operand->opcode() != HloOpcode::kConstant) { - return ""; - } - return StrCat("constant ", stringify_constant(operand)); - } - std::vector lines; for (int64 i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); + const auto* constant_operand = DynCast(operand); optional operand_str; - if (operand->opcode() == HloOpcode::kConstant) { - operand_str = stringify_constant(operand); + if (constant_operand != nullptr) { + operand_str = stringify_constant(constant_operand); } else if (ShouldMergeIntoUsers(operand)) { - // Special case: If the operand is a parameter, use its parameter number - // rather than its name, because that's generally how people think of the - // node. + // Special case: If the operand is a parameter to a fusion node and it + // always has a constant value, display it like a regular constant. + // + // For other parameters, use the parameter number rather than the proper + // name, because that's generally how people think of the node. if (operand->opcode() == HloOpcode::kParameter) { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + if (const HloConstantInstruction* constant = + TryGetFusionParameterConstant(operand)) { + operand_str = stringify_constant(constant); + } else { + operand_str = Printf("Parameter %lld", operand->parameter_number()); + } } else { operand_str = operand->name(); } @@ -897,11 +913,14 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { const auto kParameterColor = kOrange; // Special case: If this instruction has a parameter merged into it, paint it - // the same color as a parameter. + // the same color as a parameter. Unless the merged-in parameter is a + // parameter to a fusion node that is bound to a constant -- these aren't + // "real" parameters from the user's perspective. if (std::any_of(instr->operands().begin(), instr->operands().end(), [&](const HloInstruction* operand) { return operand->opcode() == HloOpcode::kParameter && - ShouldMergeIntoUsers(operand); + ShouldMergeIntoUsers(operand) && + TryGetFusionParameterConstant(operand) == nullptr; })) { return kParameterColor; } @@ -964,6 +983,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: @@ -975,7 +995,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { } return kGreen; case HloOpcode::kConcatenate: - case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kPad: @@ -997,6 +1016,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kCopy: + // Emphasize copy nodes, which are either physical transposes (and thus + // significant), or copies of read-only buffers (and thus dead weight). + return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: case HloOpcode::kFft: diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 8e52d926d85f1ce6fabeb2dedd2f8e0fe0c2051d..68f41a1cbb4db228f5dcf8b4a6130f05e81262a8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -121,7 +121,7 @@ TEST(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(-42))); - instruction->set_name("i_am_a_constant_root_instruction"); + instruction->SetAndSanitizeName("i_am_a_constant_root_instruction"); HloModuleConfig config; HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 1c276b9305d3edc2f575130bea2d1b8eac0af13c..2d496daab085e17dcf8b32f82b0444b0815d54e1 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include -#include #include #include #include @@ -27,14 +26,15 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -60,107 +60,334 @@ StatusOr> HloInstruction::CreateFromProto( TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); - auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); - for (const int64 operand_id : proto.operand_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) - << "No instruction with id " << operand_id; - instruction->AppendOperand(instruction_map.at(operand_id)); - } - for (const int64 predecessor_id : proto.control_predecessor_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) - << "No instruction with id " << predecessor_id; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) - ->AddControlDependencyTo(instruction.get())); - } - - // In the proto, fused computations are held exclusively within the - // HloInstructionProto and do not appear as an HloComputationProto within the - // HloModuleProto. - if (instruction->opcode() == HloOpcode::kFusion) { - TF_RET_CHECK(!proto.fusion_kind().empty()); - TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, - StringToFusionKind(proto.fusion_kind())); - - // Find the fused computation and set its fusion instruction. - TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "Expect 1 called computation for fusion instruction, but sees " - << proto.called_computation_ids_size(); - const int64 fusion_id = proto.called_computation_ids(0); - auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); - TF_RET_CHECK(fused_computation != nullptr) - << "No fusion computation with id " << fusion_id; - fused_computation->SetFusionInstruction(instruction.get()); - instruction->called_computations_.push_back(fused_computation); - } else { - for (const int64 computation_id : proto.called_computation_ids()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_id)) - << "No computation with id " << computation_id; - instruction->called_computations_.push_back( - computation_map.at(computation_id)); + std::unique_ptr instruction; + const auto operands = [&instruction_map, &proto](int index) { + return instruction_map.at(proto.operand_ids(index)); + }; + const auto all_operands = [&instruction_map, &proto]() { + std::vector result(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + result.begin(), [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + return result; + }; + const auto computations = [&computation_map, &proto](int index) { + return computation_map.at(proto.called_computation_ids(index)); + }; + switch (opcode) { + // Ops migrated to subclasses. + case HloOpcode::kBatchNormTraining: + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "BatchNormTraining instruction should have 3 operands but sees " + << proto.operand_ids_size(); + instruction = CreateBatchNormTraining( + proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), + proto.feature_index()); + break; + case HloOpcode::kBatchNormInference: + TF_RET_CHECK(proto.operand_ids_size() == 5) + << "BatchNormInference instruction should have 5 operands but sees " + << proto.operand_ids_size(); + instruction = CreateBatchNormInference( + proto.shape(), operands(0), operands(1), operands(2), operands(3), + operands(4), proto.epsilon(), proto.feature_index()); + break; + case HloOpcode::kBatchNormGrad: + TF_RET_CHECK(proto.operand_ids_size() == 5) + << "BatchNormGrad instruction should have 5 operands but sees " + << proto.operand_ids_size(); + instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), + operands(2), operands(3), operands(4), + proto.epsilon(), proto.feature_index()); + break; + case HloOpcode::kFft: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Fft instruction should have 1 operand but sees " + << proto.operand_ids_size(); + std::vector fft_length(proto.fft_length().begin(), + proto.fft_length().end()); + instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), + tensorflow::gtl::ArraySlice(fft_length)); + break; + } + case HloOpcode::kSend: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Send instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateSend(operands(0), proto.channel_id()); + break; + case HloOpcode::kSendDone: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "SendDone instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateSendDone(operands(0)); + break; + case HloOpcode::kRecv: + TF_RET_CHECK(proto.operand_ids_size() == 0) + << "Recv instruction should have 0 operand but sees " + << proto.operand_ids_size(); + instruction = + CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); + break; + case HloOpcode::kRecvDone: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "RecvDone instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateRecvDone(operands(0)); + break; + case HloOpcode::kReverse: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Reverse instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateReverse(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kConcatenate: + TF_RET_CHECK(proto.dimensions_size() == 1) + << "Concatenate instruction should have 1 dimension but sees " + << proto.dimensions_size(); + instruction = + CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0)); + break; + case HloOpcode::kReduce: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Reduce instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Reduce instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + instruction = CreateReduce(proto.shape(), operands(0), operands(1), + std::vector(proto.dimensions().begin(), + proto.dimensions().end()), + computations(0)); + break; + case HloOpcode::kTranspose: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Transpose instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = + CreateTranspose(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kBroadcast: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Broadcast instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = + CreateBroadcast(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kMap: + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Map instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + instruction = CreateMap(proto.shape(), all_operands(), computations(0)); + break; + case HloOpcode::kSlice: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Slice instruction should have 1 operand but sees " + << proto.operand_ids_size(); + std::vector slice_starts, slice_limits, slice_strides; + for (const HloInstructionProto::SliceDimensions& slice_dimensions : + proto.slice_dimensions()) { + slice_starts.push_back(slice_dimensions.start()); + slice_limits.push_back(slice_dimensions.limit()); + slice_strides.push_back(slice_dimensions.stride()); + } + instruction = CreateSlice(proto.shape(), operands(0), slice_starts, + slice_limits, slice_strides); + break; + } + case HloOpcode::kConstant: { + // TODO(b/110214922): Revert this to CHECK(proto.has_literal()). + if (proto.has_literal()) { + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateConstant(std::move(literal)); + } else { + instruction = MakeUnique(proto.shape()); + } + break; + } + case HloOpcode::kTrace: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Trace instruction should have 1 operand but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_literal()); + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + break; + } + case HloOpcode::kFusion: { + // In the proto, fused computations are held exclusively within the + // HloInstructionProto and do not appear as an HloComputationProto within + // the HloModuleProto. + TF_RET_CHECK(!proto.fusion_kind().empty()); + TF_ASSIGN_OR_RETURN(FusionKind fusion_kind, + StringToFusionKind(proto.fusion_kind())); + + // Find the fused computation and set its fusion instruction. + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Expect 1 called computation for fusion instruction but sees " + << proto.called_computation_ids_size(); + const int64 fusion_id = proto.called_computation_ids(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + TF_RET_CHECK(fused_computation != nullptr) + << "No fusion computation with id " << fusion_id; + instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), + fused_computation); + break; + } + case HloOpcode::kRng: + instruction = + CreateRng(proto.shape(), proto.distribution(), all_operands()); + break; + case HloOpcode::kParameter: + instruction = CreateParameter(proto.parameter_number(), proto.shape(), + proto.name()); + break; + case HloOpcode::kGetTupleElement: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "GetTupleElement instruction should have 1 operand but sees " + << proto.operand_ids_size(); + instruction = CreateGetTupleElement(proto.shape(), operands(0), + proto.tuple_index()); + break; + case HloOpcode::kReducePrecision: + instruction = + CreateReducePrecision(proto.shape(), operands(0), + proto.exponent_bits(), proto.mantissa_bits()); + break; + case HloOpcode::kInfeed: + instruction = CreateInfeed(proto.shape(), proto.infeed_config()); + break; + case HloOpcode::kOutfeed: + instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), + proto.outfeed_config()); + break; + case HloOpcode::kCrossReplicaSum: { + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "CrossReplicaSum should have 1 called computation but sees " + << proto.called_computation_ids_size(); + tensorflow::gtl::optional all_reduce_id; + if (proto.all_reduce_id() > 0) { + all_reduce_id = proto.all_reduce_id(); + } + instruction = CreateCrossReplicaSum( + proto.shape(), all_operands(), computations(0), + /*replica_group_ids=*/ + std::vector(proto.replica_group_ids().begin(), + proto.replica_group_ids().end()), + /*barrier=*/proto.cross_replica_sum_barrier(), + /*all_reduce_id=*/all_reduce_id); + break; + } + case HloOpcode::kConvolution: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Convolution instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_window()); + TF_RET_CHECK(proto.has_convolution_dimension_numbers()); + instruction = + CreateConvolve(proto.shape(), operands(0), operands(1), + proto.window(), proto.convolution_dimension_numbers()); + break; + case HloOpcode::kReduceWindow: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "ReduceWindow instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "ReduceWindow should have 1 called computation but sees " + << proto.called_computation_ids_size(); + instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1), + proto.window(), computations(0)); + break; + case HloOpcode::kSelectAndScatter: + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "SelectAndScatter instruction should have 3 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 2) + << "SelectAndScatter should have 2 called computations but sees " + << proto.called_computation_ids_size(); + instruction = CreateSelectAndScatter( + proto.shape(), operands(0), computations(0), proto.window(), + operands(1), operands(2), computations(1)); + break; + case HloOpcode::kCustomCall: + instruction = CreateCustomCall(proto.shape(), all_operands(), + proto.custom_call_target()); + if (proto.has_window()) { + static_cast(instruction.get()) + ->set_window(proto.window()); + } + if (proto.has_convolution_dimension_numbers()) { + static_cast(instruction.get()) + ->set_convolution_dimension_numbers( + proto.convolution_dimension_numbers()); + } + break; + case HloOpcode::kHostCompute: + instruction = + CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(), + proto.cost_estimate_ns()); + break; + case HloOpcode::kPad: + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Pad instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_padding_config()); + instruction = CreatePad(proto.shape(), operands(0), operands(1), + proto.padding_config()); + break; + case HloOpcode::kDynamicSlice: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "DynamicSlice instruction should have 2 operands but sees " + << proto.operand_ids_size(); + std::vector slice_sizes(proto.dynamic_slice_sizes_size()); + c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), + slice_sizes); + break; + } + default: { + instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + for (const int64 operand_id : proto.operand_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) + << "No instruction with id " << operand_id; + instruction->AppendOperand(instruction_map.at(operand_id)); + } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) + ->AddControlDependencyTo(instruction.get())); + } + if (instruction->opcode() != HloOpcode::kFusion) { + for (const int64 computation_id : proto.called_computation_ids()) { + TF_RET_CHECK(ContainsKey(computation_map, computation_id)) + << "No computation with id " << computation_id; + instruction->called_computations_.push_back( + computation_map.at(computation_id)); + } + } + break; } - } - - if (instruction->opcode() == HloOpcode::kTrace) { - TF_RET_CHECK(instruction->operands().size() == 1) - << "Trace instruction should have 1 operand but sees " - << instruction->operands().size(); - instruction->mutable_operand(0)->set_tracing(instruction.get()); } TF_RET_CHECK(!proto.name().empty()); - instruction->name_ = proto.name(); - + instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); - if (proto.has_literal()) { - TF_ASSIGN_OR_RETURN(instruction->literal_, - Literal::CreateFromProto(proto.literal())); - } - instruction->parameter_number_ = proto.parameter_number(); - instruction->tuple_index_ = proto.tuple_index(); - for (int64 dimension : proto.dimensions()) { - instruction->dimensions_.push_back(dimension); - } - if (proto.has_window()) { - instruction->window_ = MakeUnique(proto.window()); - } - if (proto.has_convolution_dimension_numbers()) { - instruction->convolution_dimension_numbers_ = - MakeUnique( - proto.convolution_dimension_numbers()); - } if (proto.has_dot_dimension_numbers()) { instruction->dot_dimension_numbers_ = MakeUnique(proto.dot_dimension_numbers()); } - for (const HloInstructionProto::SliceDimensions& slice_dimensions : - proto.slice_dimensions()) { - instruction->slice_starts_.push_back(slice_dimensions.start()); - instruction->slice_limits_.push_back(slice_dimensions.limit()); - instruction->slice_strides_.push_back(slice_dimensions.stride()); - } - instruction->exponent_bits_ = proto.exponent_bits(); - instruction->mantissa_bits_ = proto.mantissa_bits(); - for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { - instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); - } - if (proto.has_padding_config()) { - instruction->padding_config_ = - MakeUnique(proto.padding_config()); - } - instruction->outfeed_config_ = proto.outfeed_config(); - instruction->distribution_ = proto.distribution(); - instruction->epsilon_ = proto.epsilon(); - instruction->feature_index_ = proto.feature_index(); - instruction->channel_id_ = proto.channel_id(); - instruction->infeed_config_ = proto.infeed_config(); - instruction->custom_call_target_ = proto.custom_call_target(); - instruction->outfeed_shape_ = proto.outfeed_shape(); - instruction->fft_type_ = proto.fft_type(); - for (int64 fft_len : proto.fft_length()) { - instruction->fft_length_.push_back(fft_len); - } if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -175,61 +402,34 @@ StatusOr> HloInstruction::CreateFromProto( for (int64 bound : proto.gather_window_bounds()) { instruction->gather_window_bounds_.push_back(bound); } - - instruction->channel_name_ = proto.channel_name(); - instruction->cost_estimate_ns_ = proto.cost_estimate_ns(); - return std::move(instruction); } /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); - instruction->parameter_number_ = parameter_number; - instruction->name_ = name; - return instruction; + return MakeUnique(parameter_number, shape, name); } /* static */ std::unique_ptr HloInstruction::CreateTrace( const string& tag, HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); - instruction->operands_.push_back(operand); - instruction->literal_ = Literal::CreateR1U8(tag); - operand->set_tracing(instruction.get()); - return instruction; + return MakeUnique(tag, operand); } /* static */ std::unique_ptr HloInstruction::CreateConstant( std::unique_ptr literal) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape())); - instruction->literal_ = std::move(literal); - return instruction; + return MakeUnique(std::move(literal)); } /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - CHECK(ShapeUtil::IsTuple(operand->shape())); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape)); - instruction->tuple_index_ = index; - instruction->AppendOperand(operand); - return instruction; + return MakeUnique(shape, operand, index); } /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape)); - instruction->distribution_ = distribution; - instruction->shape_ = shape; - for (HloInstruction* param : parameters) { - instruction->AppendOperand(param); - } - return instruction; + return MakeUnique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( @@ -344,43 +544,22 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* map_computation, tensorflow::gtl::ArraySlice static_operands) { - CHECK(static_operands.empty()) << "static_operands not yet supported"; - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->called_computations_.push_back(map_computation); - return instruction; + return MakeUnique(shape, operands, map_computation, + static_operands); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape)); - if (window_util::HasBaseDilation(window)) { - instruction->name_ = instruction->name() + "-base-dilated"; - } - if (window_util::HasWindowDilation(window)) { - instruction->name_ = instruction->name() + "-window-dilated"; - } - instruction->AppendOperand(lhs); - instruction->AppendOperand(rhs); - instruction->window_ = MakeUnique(window); - instruction->convolution_dimension_numbers_ = - MakeUnique(dimension_numbers); - return instruction; + return MakeUnique(shape, lhs, rhs, window, + dimension_numbers); } /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, tensorflow::gtl::ArraySlice fft_length) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape)); - instruction->AppendOperand(operand); - instruction->fft_type_ = fft_type; - instruction->fft_length_.assign(fft_length.begin(), fft_length.end()); - return instruction; + return MakeUnique(shape, operand, fft_type, fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( @@ -413,93 +592,73 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape)); - instruction->AppendOperand(operand); - instruction->exponent_bits_ = exponent_bits; - instruction->mantissa_bits_ = mantissa_bits; - return instruction; + return MakeUnique( + shape, operand, exponent_bits, mantissa_bits); } /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands) { - return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id) { + return MakeUnique( + shape, operands, reduce_computation, replica_group_ids, barrier, + all_reduce_id); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( const Shape& shape, const string& config) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape)); - instruction->set_infeed_config(config); - return instruction; + return MakeUnique(shape, config); } /* static */ std::unique_ptr HloInstruction::CreateOutfeed( const Shape& shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil())); - CHECK(ShapeUtil::Compatible(operand->shape(), shape)) - << "Outfeed shape " << shape << " must be compatible with operand shape " - << operand->shape(); - instruction->AppendOperand(operand); - instruction->outfeed_config_ = std::string(outfeed_config); - instruction->outfeed_shape_ = shape; - return instruction; + return MakeUnique(shape, operand, outfeed_config); } /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, int64 channel_id) { - // Send instruction produces a tuple of {aliased operand, U32 context}. - Shape output_shape = ShapeUtil::MakeTupleShape( - {operand->shape(), ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = channel_id; - return instruction; + return MakeUnique(operand, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kSend) + auto send_operand = DynCast(operand); + CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - auto instruction = WrapUnique( - new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil())); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(send_operand); } /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, int64 channel_id) { - // Recv instruction produces a tuple of {receive buffer, U32 context}. - Shape output_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape)); - instruction->channel_id_ = channel_id; - return instruction; + return MakeUnique(shape, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kRecv) + auto recv_operand = DynCast(operand); + CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(recv_operand); } /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); + return MakeUnique(shape, operand, dimensions); +} + +/* static */ std::unique_ptr +HloInstruction::CreateGenerateToken( + tensorflow::gtl::ArraySlice operands) { + auto instruction = WrapUnique(new HloInstruction( + HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape())); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } return instruction; } @@ -536,30 +695,15 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape)); - instruction->AppendOperand(operand); - instruction->slice_starts_.assign(start_indices.begin(), start_indices.end()); - instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end()); - instruction->slice_strides_.assign(strides.begin(), strides.end()); - // For backward compatibility with old serialized computations: if there are - // no strides, assume all strides are 1. - // TODO(b/63317920): remove this code. - if (instruction->slice_strides_.empty()) { - instruction->slice_strides_ = std::vector(start_indices.size(), 1LL); - } - return instruction; + return MakeUnique(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kDynamicSlice, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(start_indices); - instruction->dynamic_slice_sizes_.assign(slice_sizes.begin(), - slice_sizes.end()); - return instruction; + return MakeUnique(shape, operand, start_indices, + slice_sizes); } /* static */ std::unique_ptr @@ -578,13 +722,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConcatenate( const Shape& shape, tensorflow::gtl::ArraySlice operands, int64 dimension) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->dimensions_.push_back(dimension); - return instruction; + return MakeUnique(shape, operands, dimension); } /* static */ std::unique_ptr HloInstruction::CreateConvert( @@ -607,25 +745,15 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, const Shape& shape, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape)); - instruction->AppendOperand(arg); - instruction->AppendOperand(init_value); - instruction->dimensions_.assign(dimensions_to_reduce.begin(), - dimensions_to_reduce.end()); - instruction->called_computations_.push_back(reduce_computation); - return instruction; + return MakeUnique( + shape, arg, init_value, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(init_value); - instruction->called_computations_.push_back(reduce_computation); - instruction->window_ = MakeUnique(window); - return instruction; + return MakeUnique(shape, operand, init_value, + window, reduce_computation); } /* static */ std::unique_ptr @@ -634,14 +762,8 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(offset); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique( + shape, operand, scale, offset, epsilon, feature_index); } /* static */ std::unique_ptr @@ -649,16 +771,8 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(offset); - instruction->AppendOperand(mean); - instruction->AppendOperand(variance); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique( + shape, operand, scale, offset, mean, variance, epsilon, feature_index); } /* static */ std::unique_ptr @@ -667,16 +781,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(mean); - instruction->AppendOperand(variance); - instruction->AppendOperand(grad_output); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique(shape, operand, scale, mean, + variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr @@ -684,27 +791,15 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, const Window& window, HloInstruction* source, HloInstruction* init_value, HloComputation* scatter) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSelectAndScatter, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(source); - instruction->AppendOperand(init_value); - // Select comes before scatter in the vector. - instruction->called_computations_.push_back(select); - instruction->called_computations_.push_back(scatter); - instruction->window_ = MakeUnique(window); - return instruction; + return MakeUnique( + shape, operand, select, window, source, init_value, scatter); } /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(broadcast_dimensions.begin(), - broadcast_dimensions.end()); - return instruction; + return MakeUnique(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr @@ -762,11 +857,8 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kPad, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(padding_value); - instruction->padding_config_ = MakeUnique(padding_config); - return instruction; + return MakeUnique(shape, operand, padding_value, + padding_config); } /* static */ std::unique_ptr HloInstruction::CreateReshape( @@ -783,53 +875,28 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - CHECK_EQ(shape.dimensions().size(), dimensions.size()); - CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); - CHECK(std::equal(operand->shape().dimensions().begin(), - operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())) - << "shape: " << ShapeUtil::HumanString(shape) - << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); - return instruction; + return MakeUnique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - instruction->fusion_kind_ = fusion_kind; - instruction->name_ = "fusion"; - instruction->set_parent(fused_root->parent()); - instruction->set_metadata(fused_root->metadata()); - instruction->CloneAndFuseInternal(fused_root); - return instruction; + return MakeUnique(shape, fusion_kind, fused_root); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->fusion_kind_ = fusion_kind; - instruction->name_ = "fusion"; - instruction->called_computations_.push_back(fusion_computation); - fusion_computation->SetFusionInstruction(instruction.get()); - return instruction; + return MakeUnique(shape, fusion_kind, operands, + fusion_computation); } -void HloInstruction::set_device_sharding(int64 device) { - HloSharding device_sharding = HloSharding::AssignDevice(device); +void HloInstruction::set_single_sharding(const HloSharding& sharding) { + CHECK(!sharding.IsTuple()) << sharding; if (ShapeUtil::IsTuple(shape())) { - set_sharding(HloSharding::Tuple(device_sharding.GetAsShapeTree(shape()))); + set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape()))); } else { - set_sharding(device_sharding); + set_sharding(sharding); } } @@ -843,289 +910,6 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->set_metadata(metadata_); } -HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { - CHECK_EQ(opcode(), HloOpcode::kFusion); - CHECK_EQ(operand_count(), - fused_instructions_computation()->parameter_instructions().size()); - const int64 param_no = operand_count(); - // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. - string param_name = StrCat(new_operand->name(), ".param_", param_no); - HloInstruction* fused_parameter = - fused_instructions_computation()->AddParameter( - HloInstruction::CreateParameter(param_no, new_operand->shape(), - param_name)); - AppendOperand(new_operand); - return fused_parameter; -} - -void HloInstruction::MergeFusionInstruction( - HloInstruction* instruction_to_merge) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); - CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != - operands().end()); - // Clone the instruction from which to merge fused instructions. - std::unique_ptr clone = instruction_to_merge->Clone(); - // Replace uses of fused parameters with the corresponding operand of the - // fusion. Add all non-parameter fused instructions to 'unfused_instructions' - // to be merged into 'this'. This is done in reverse post order. - std::vector unfused_instructions; - auto fused_instructions = - clone->fused_instructions_computation()->MakeInstructionPostOrder(); - for (auto fused_it = fused_instructions.rbegin(); - fused_it != fused_instructions.rend(); ++fused_it) { - auto fused_instruction = *fused_it; - if (fused_instruction->opcode() == HloOpcode::kParameter) { - TF_CHECK_OK(fused_instruction->ReplaceAllUsesWith( - clone->mutable_operand(fused_instruction->parameter_number()))); - } else { - unfused_instructions.push_back(fused_instruction); - } - } - CHECK(unfused_instructions.front() == clone->fused_expression_root()); - // Replace instruction_to_merge use of 'this' with unfused_root. - TF_CHECK_OK( - instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); - // Fuse 'unfused_instructions' into 'this'. - for (auto& instruction : unfused_instructions) { - FuseInstruction(instruction); - instruction->DetachFromOperands(); - } - CHECK_EQ(0, clone->user_count()); - clone->DetachFromOperands(); - TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( - clone->fused_instructions_computation())); -} - -void HloInstruction::MergeFusionInstructionIntoMultiOutput( - HloInstruction* instruction_to_merge) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); - // Add all non-parameter fused instructions to 'unfused_instructions' to be - // merged into 'this'. `old_to_new' maps the instructions in the fused node - // to the disaseembled fusion instructions. - // Note that we add the unfused instructions to this->parent_ computation. - // This is necessary because the unique_id needs for an instruction and - // it's only added when inserting to the computation. - tensorflow::gtl::FlatMap old_to_new; - std::vector unfused_instructions; - auto computation_to_merge = - instruction_to_merge->fused_instructions_computation(); - auto post_order = computation_to_merge->MakeInstructionPostOrder(); - for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) { - auto fused_instruction = *rit; - if (fused_instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie(&old_to_new, fused_instruction, - instruction_to_merge->mutable_operand( - fused_instruction->parameter_number())); - continue; - } - - // Here we clone the insertion and call FuseInstructionIntoMultiOutput() - // which clones again. This can be improved. - auto cloned_instruction = - parent_->AddInstruction(fused_instruction->Clone()); - unfused_instructions.push_back(cloned_instruction); - InsertOrDie(&old_to_new, fused_instruction, cloned_instruction); - } - for (auto unfused_instruction : unfused_instructions) { - for (int64 index = 0; index < unfused_instruction->operand_count(); - index++) { - auto new_operand = - FindOrDie(old_to_new, unfused_instruction->mutable_operand(index)); - TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand)); - } - } - - HloInstruction* unfused_root = unfused_instructions.front(); - TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); - - TF_CHECK_OK( - instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge)); - if (GetModule()) { - TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge)); - } - - // Fuse the root instruction and generate multiple outputs. - FuseInstructionIntoMultiOutput(unfused_root); - TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); - // The rest instructions are of normal fusing. - for (int64 i = 1; i < unfused_instructions.size(); i++) { - auto instruction = unfused_instructions[i]; - FuseInstruction(instruction); - TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); - } -} - -HloInstruction* HloInstruction::FuseInstructionInternal( - HloInstruction* instruction_to_fuse, bool add_output) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - - // When add_output is false, this fusion instruction must be a user of - // instruction_to_fuse. - if (!add_output) { - CHECK(IsUserOf(instruction_to_fuse)); - } - HloInstruction* fused_instruction = - CloneAndFuseInternal(instruction_to_fuse, add_output); - return fused_instruction; -} - -HloInstruction* HloInstruction::CloneAndFuseInternal( - HloInstruction* instruction_to_fuse, bool add_output) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); - VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); - HloInstruction* clone = nullptr; - if (called_computations_.empty()) { - // New fusion instruction. It should not be a multioutput instruction. - CHECK(!add_output); - auto builder = HloComputation::Builder("fused_computation", this); - builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); - called_computations_.push_back( - CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); - clone = fused_expression_root(); - } else { - clone = fused_instructions_computation()->AddInstruction( - instruction_to_fuse->Clone(/*suffix=*/"")); - // When add_output is false, instruction_to_fuse is necessarily an operand - // of the fusion instruction. After fusion this will no longer be the case. - // Remove the operand from the operand list and remove its corresponding - // fused parameter instruction. Renumber parameters as necessary to make - // parameter numbers consistent with their index in the - // fused_parameter_ vector. - bool in_operand_list = std::find(operands_.begin(), operands_.end(), - instruction_to_fuse) != operands_.end(); - CHECK(add_output || in_operand_list); - const std::vector& fused_parameters = - fused_instructions_computation()->parameter_instructions(); - for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { - if (instruction_to_fuse == operands_[operand_num]) { - // replace the fused parameter instruction's uses with the clone. - HloInstruction* fused_parameter = fused_parameters[operand_num]; - TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone)); - - // Remove the corresponding fused parameter and operand from their - // respective vectors. - TF_CHECK_OK( - fused_instructions_computation()->RemoveParameter(operand_num)); - operands_.erase(operands_.begin() + operand_num); - break; - } - } - // We've cloned instruction_to_fuse into this fusion instruction, so this - // fusion instruction is no longer a use of instruction_to_fuse. - if (in_operand_list) { - instruction_to_fuse->RemoveUser(this); - // When the instruction_to_fuse does not have other users, we don't need - // to generate a multioutput fusion instruction. - if (instruction_to_fuse->user_count() == 0) { - add_output = false; - } - } - } - - // Reread the parameters in the computation. - const std::vector& fused_parameters = - fused_instructions_computation()->parameter_instructions(); - - // Add each operand of the clone as an operand of the fusion instruction. A - // complication is that some clone operands may already be operands of the - // fusion instruction. - for (int64 operand_num = 0; operand_num < clone->operand_count(); - ++operand_num) { - HloInstruction* operand = clone->mutable_operand(operand_num); - - // See if this operand is already an operand of the fusion node. - CHECK_EQ(operands_.size(), fused_parameters.size()); - HloInstruction* fused_param = nullptr; - for (int64 i = 0; i < operands_.size(); ++i) { - if (operands_[i] == operand) { - fused_param = fused_parameters[i]; - break; - } - } - - if (fused_param == nullptr) { - // Clone's operand was not already an operand of the fusion - // instruction. Add it as an operand and add a corresponding fused - // parameter instruction. - fused_param = AddFusionOperand(operand); - } - TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); - } - - if (add_output) { - CHECK_GT(instruction_to_fuse->user_count(), 0); - // If this is already a multioutput fusion instruction, expand the root - // tuple by 1. - HloInstruction* fused_root = fused_expression_root(); - HloInstruction::InstructionVector tuple_elements; - bool newly_created_tuple_instr = false; - if (fused_root->opcode() == HloOpcode::kTuple) { - tuple_elements = fused_root->operands(); - } else { - tuple_elements.push_back(fused_root); - newly_created_tuple_instr = true; - } - if (clone->opcode() == HloOpcode::kTuple) { - for (auto inst : clone->operands()) { - tuple_elements.push_back(inst); - } - } else { - tuple_elements.push_back(clone); - } - HloInstruction* new_root = fused_instructions_computation()->AddInstruction( - HloInstruction::CreateTuple(tuple_elements)); - fused_instructions_computation()->set_root_instruction(new_root); - shape_ = new_root->shape(); - if (fused_root->opcode() == HloOpcode::kTuple) { - TF_CHECK_OK( - fused_instructions_computation()->RemoveInstruction(fused_root)); - } - - // If this is a newly created multioutput instruction, we need to update - // the use of the original fusion instruction. - if (newly_created_tuple_instr) { - HloInstruction* new_instr = parent_->AddInstruction( - HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); - TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); - } - int64 index = tuple_elements.size(); - if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { - index -= instruction_to_fuse->operand_count(); - std::vector to_be_removed; - for (auto old_gte : instruction_to_fuse->users()) { - CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); - int64 old_tuple_index = old_gte->tuple_index(); - HloInstruction* new_gte = - parent_->AddInstruction(HloInstruction::CreateGetTupleElement( - old_gte->shape(), this, index + old_tuple_index)); - TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte)); - to_be_removed.push_back(old_gte); - } - for (auto old_gte : to_be_removed) { - TF_CHECK_OK(parent_->RemoveInstruction(old_gte)); - } - TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); - } else { - HloInstruction* new_gte = - parent_->AddInstruction(HloInstruction::CreateGetTupleElement( - clone->shape(), this, index - 1)); - TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte)); - } - } - - VLOG(2) << "New clone:\n" << clone->ToString(); - return clone; -} - -RandomDistribution HloInstruction::random_distribution() const { - CHECK_EQ(opcode_, HloOpcode::kRng); - return distribution_; -} - bool HloInstruction::HasSideEffectNoRecurse() const { switch (opcode_) { case HloOpcode::kSend: @@ -1171,26 +955,15 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kCustomCall, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->custom_call_target_ = std::string(custom_call_target); - return instruction; + return MakeUnique(shape, operands, + custom_call_target); } /* static */ std::unique_ptr HloInstruction::CreateHostCompute( const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { - std::unique_ptr instruction = - WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->channel_name_ = std::string(channel_name); - instruction->cost_estimate_ns_ = cost_estimate_ns; - return instruction; + return MakeUnique(shape, operands, channel_name, + cost_estimate_ns); } /* static */ std::unique_ptr HloInstruction::CreateTuple( @@ -1263,6 +1036,42 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. switch (opcode_) { + // Ops migrated to subclasses. + // TODO(b/80131774): Remove this switch when migration is complete. + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kConvolution: + case HloOpcode::kCustomCall: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kHostCompute: + case HloOpcode::kPad: + case HloOpcode::kDynamicSlice: + clone = CloneWithNewOperandsImpl(shape, new_operands, context); + break; // Unary ops. case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -1321,31 +1130,9 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[2]); break; // Other supported ops. - case HloOpcode::kBroadcast: - CHECK_EQ(new_operands.size(), 1); - clone = CreateBroadcast(shape, new_operands[0], dimensions_); - break; case HloOpcode::kCall: clone = CreateCall(shape, new_operands, to_apply()); break; - case HloOpcode::kCustomCall: - clone = CreateCustomCall(shape, new_operands, custom_call_target_); - if (window_ != nullptr) { - clone->window_ = MakeUnique(*window_); - } - if (convolution_dimension_numbers_ != nullptr) { - clone->convolution_dimension_numbers_ = - MakeUnique( - *convolution_dimension_numbers_); - } - break; - case HloOpcode::kHostCompute: - clone = CreateHostCompute(shape, new_operands, channel_name_, - cost_estimate_ns_); - break; - case HloOpcode::kConcatenate: - clone = CreateConcatenate(shape, new_operands, dimensions(0)); - break; case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); @@ -1354,85 +1141,20 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kReducePrecision: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, - mantissa_bits_); - break; - case HloOpcode::kConvolution: - CHECK_EQ(new_operands.size(), 2); - clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, - *convolution_dimension_numbers_); - break; case HloOpcode::kDot: CHECK_EQ(new_operands.size(), 2); clone = CreateDot(shape, new_operands[0], new_operands[1], *dot_dimension_numbers_); break; - case HloOpcode::kFft: - CHECK_EQ(new_operands.size(), 1); - clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); - break; - case HloOpcode::kCrossReplicaSum: - clone = CreateCrossReplicaSum(shape, new_operands); - break; - case HloOpcode::kGetTupleElement: - CHECK_EQ(new_operands.size(), 1); - clone = CreateGetTupleElement(shape, new_operands[0], tuple_index()); - break; - case HloOpcode::kMap: - clone = CreateMap(shape, new_operands, to_apply()); - break; - case HloOpcode::kPad: - CHECK_EQ(new_operands.size(), 2); - clone = - CreatePad(shape, new_operands[0], new_operands[1], *padding_config_); - break; - case HloOpcode::kReduce: - CHECK_EQ(new_operands.size(), 2); - clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, - to_apply()); - break; - case HloOpcode::kReduceWindow: - CHECK_EQ(new_operands.size(), 2); - clone = CreateReduceWindow(shape, new_operands[0], new_operands[1], - *window_, to_apply()); - break; - case HloOpcode::kSelectAndScatter: - CHECK_EQ(new_operands.size(), 3); - clone = - CreateSelectAndScatter(shape, new_operands[0], select(), *window_, - new_operands[1], new_operands[2], scatter()); - break; - case HloOpcode::kReverse: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReverse(shape, new_operands[0], dimensions_); - break; - case HloOpcode::kRng: - clone = CreateRng(shape, distribution_, new_operands); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); break; - case HloOpcode::kSlice: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, - slice_strides_); - break; - case HloOpcode::kDynamicSlice: - clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1], - dynamic_slice_sizes_); - break; case HloOpcode::kDynamicUpdateSlice: CHECK_EQ(new_operands.size(), 3); clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], new_operands[2]); break; - case HloOpcode::kTranspose: - CHECK_EQ(new_operands.size(), 1); - clone = CreateTranspose(shape, new_operands[0], dimensions_); - break; case HloOpcode::kTuple: clone = CreateTuple(new_operands); *clone->mutable_shape() = shape; @@ -1442,78 +1164,12 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateWhile(shape, while_condition(), while_body(), new_operands[0]); break; - case HloOpcode::kConstant: - clone = CreateConstant(literal_->CloneToUnique()); - break; - case HloOpcode::kFusion: { - HloModule* module = context != nullptr ? context->module() : GetModule(); - HloComputation* new_fused_computation = nullptr; - if (context != nullptr) { - new_fused_computation = - context->FindComputation(fused_instructions_computation()); - } - if (new_fused_computation == nullptr) { - new_fused_computation = module->AddEmbeddedComputation( - fused_instructions_computation()->Clone("clone", context)); - } - clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), - /*operands=*/new_operands, - /*fusion_computation=*/new_fused_computation); - break; - } - case HloOpcode::kParameter: - clone = CreateParameter(parameter_number_, shape, name_); - break; - case HloOpcode::kBatchNormTraining: - CHECK_EQ(new_operands.size(), 3); - clone = - CreateBatchNormTraining(shape, new_operands[0], new_operands[1], - new_operands[2], epsilon(), feature_index()); - break; - case HloOpcode::kBatchNormInference: - CHECK_EQ(new_operands.size(), 5); - clone = CreateBatchNormInference( - shape, new_operands[0], new_operands[1], new_operands[2], - new_operands[3], new_operands[4], epsilon(), feature_index()); - break; - case HloOpcode::kInfeed: - CHECK_EQ(new_operands.size(), 0); - clone = CreateInfeed(shape, infeed_config()); - break; - case HloOpcode::kOutfeed: - CHECK_EQ(new_operands.size(), 1); - clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); - break; - case HloOpcode::kBatchNormGrad: - CHECK_EQ(new_operands.size(), 5); - clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1], - new_operands[2], new_operands[3], - new_operands[4], epsilon(), feature_index()); - break; case HloOpcode::kConditional: CHECK_EQ(new_operands.size(), 3); clone = CreateConditional(shape, new_operands[0], new_operands[1], true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kSend: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSend(new_operands[0], channel_id()); - break; - case HloOpcode::kSendDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSendDone(new_operands[0]); - break; - case HloOpcode::kRecv: - CHECK_EQ(new_operands.size(), 0); - // The shape is a tuple, but CreateRecv() wants the raw data shape. - clone = - CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); - break; - case HloOpcode::kRecvDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateRecvDone(new_operands[0]); - break; case HloOpcode::kGather: CHECK_EQ(new_operands.size(), 2); clone = CreateGather(shape, new_operands[0], new_operands[1], @@ -1525,8 +1181,9 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), user_side_metadata_->Clone()); break; - case HloOpcode::kTrace: - LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); + case HloOpcode::kGenerateToken: + clone = CreateGenerateToken(new_operands); + break; } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); @@ -1542,7 +1199,29 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( return clone; } -HloInstruction::~HloInstruction() {} +HloInstruction::~HloInstruction() { + // Detach from operands. An instruction may be repeated as an operand. To + // avoid calling RemoveUser twice on the same operand, check before remove. + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + HloInstruction* operand = operands_[operand_num]; + if (operand == nullptr) { + continue; + } + if (operand->user_set_.find(this) != operand->user_set_.end()) { + operand->RemoveUser(this); + } + operands_[operand_num] = nullptr; + } + + // Update users. Set `nullptr` to the correpsonding operand slot for users. + for (auto& user : this->users()) { + for (int i = 0; i < user->operand_count(); ++i) { + if (user->operands_[i] == this) { + user->operands_[i] = nullptr; + } + } + } +} std::unique_ptr HloInstruction::Clone( const string& suffix, HloCloneContext* context) const { @@ -1607,40 +1286,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const { return hlo; } -const Literal& HloInstruction::literal() const { - CHECK_EQ(HloOpcode::kConstant, opcode_); - return *literal_; -} - -bool HloInstruction::HasLiteral() const { return literal_ != nullptr; } - -bool HloInstruction::CanHaveDimensionsField() const { - return (opcode() == HloOpcode::kReverse || - opcode() == HloOpcode::kConcatenate || - opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast || - opcode() == HloOpcode::kTranspose); -} - -const std::vector& HloInstruction::dimensions() const { - CHECK(CanHaveDimensionsField()); - return dimensions_; -} - -int64 HloInstruction::dimensions(int64 index) const { - return dimensions()[index]; -} - -int64 HloInstruction::concatenate_dimension() const { - CHECK(opcode() == HloOpcode::kConcatenate); - CHECK_EQ(1, dimensions_.size()); - return dimensions(0); -} - -int64 HloInstruction::tuple_index() const { - CHECK_EQ(HloOpcode::kGetTupleElement, opcode_); - return tuple_index_; -} - const HloInstruction* HloInstruction::operand(int64 i) const { return operands_[i]; } @@ -1729,10 +1374,6 @@ void HloInstruction::AddUser(HloInstruction* user) { } } -bool HloInstruction::IsConstant() const { - return opcode_ == HloOpcode::kConstant; -} - bool HloInstruction::HasConstantOperand() const { for (const HloInstruction* operand : operands_) { if (operand->IsConstant()) { @@ -1762,9 +1403,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kCos: - case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: - case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: @@ -1802,48 +1441,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kTuple: return true; - // Broadcast, Concatenate, and Transpose need the same dimensions field. - case HloOpcode::kBroadcast: - case HloOpcode::kConcatenate: - case HloOpcode::kTranspose: - return dimensions() == other.dimensions(); - - case HloOpcode::kFusion: - return fusion_kind() == other.fusion_kind() && - eq_computations(fused_instructions_computation(), - other.fused_instructions_computation()); - // These opcodes have complex or special behavior so just return false. case HloOpcode::kDomain: - case HloOpcode::kRng: - case HloOpcode::kTrace: case HloOpcode::kWhile: + case HloOpcode::kGenerateToken: return false; - case HloOpcode::kParameter: - return parameter_number() == other.parameter_number(); - - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: - case HloOpcode::kBatchNormGrad: - return feature_index() == other.feature_index() && - epsilon() == other.epsilon(); - - // A constant is defined by the value in the literal. - case HloOpcode::kConstant: - return literal() == other.literal(); - - // A reduce-precision operation is determined by the bit sizes. - case HloOpcode::kReducePrecision: - return exponent_bits() == other.exponent_bits() && - mantissa_bits() == other.mantissa_bits(); - - // Convolution has a window and dimensions. - case HloOpcode::kConvolution: - return protobuf_util::ProtobufEquals(window(), other.window()) && - protobuf_util::ProtobufEquals( - convolution_dimension_numbers(), - other.convolution_dimension_numbers()); // Check dot dimension numbers. case HloOpcode::kDot: return protobuf_util::ProtobufEquals(dot_dimension_numbers(), @@ -1854,83 +1457,56 @@ bool HloInstruction::IdenticalSlowPath( other.gather_dimension_numbers()) && gather_window_bounds() == other.gather_window_bounds(); - // FFT has various types & lengths. - case HloOpcode::kFft: - return fft_type() == other.fft_type() && - fft_length() == other.fft_length(); - - // Reduction results are determined by the reduction dimension and the - // reduction computation. - case HloOpcode::kReduce: - return dimensions() == other.dimensions() && - eq_computations(to_apply(), other.to_apply()); - case HloOpcode::kReduceWindow: - return eq_computations(to_apply(), other.to_apply()) && - protobuf_util::ProtobufEquals(window(), other.window()); - - // SelectAndScatter is determined by both select and scatter - // computation as well as the window configuration. - case HloOpcode::kSelectAndScatter: - return eq_computations(select(), other.select()) && - eq_computations(scatter(), other.scatter()) && - protobuf_util::ProtobufEquals(window(), other.window()); - - // Remaining instructions with special values. - case HloOpcode::kGetTupleElement: - return tuple_index() == other.tuple_index(); - case HloOpcode::kPad: - return protobuf_util::ProtobufEquals(padding_config(), - other.padding_config()); - case HloOpcode::kSlice: - return slice_starts_ == other.slice_starts_ && - slice_limits_ == other.slice_limits_ && - slice_strides_ == other.slice_strides_; case HloOpcode::kCall: - case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); - case HloOpcode::kCustomCall: - if ((window_ == nullptr) != (other.window_ == nullptr) || - (window_ != nullptr && - !protobuf_util::ProtobufEquals(window(), other.window()))) { - return false; - } - if ((convolution_dimension_numbers_ == nullptr) != - (other.convolution_dimension_numbers_ == nullptr) || - (convolution_dimension_numbers_ != nullptr && - !protobuf_util::ProtobufEquals( - convolution_dimension_numbers(), - other.convolution_dimension_numbers()))) { - return false; - } - return custom_call_target_ == other.custom_call_target_; - case HloOpcode::kReverse: - return dimensions() == other.dimensions(); case HloOpcode::kConditional: return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); // These opcodes are not yet supported. - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: + return false; + + // Ops migrated to subclasses should never come to this line. + // TODO(b/80131774): Remove this switch when migration is complete. + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kConvolution: + case HloOpcode::kCustomCall: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: case HloOpcode::kHostCompute: - return false; + case HloOpcode::kPad: + case HloOpcode::kDynamicSlice: + LOG(FATAL) << "Base class impl called for opcode with subclass: " + << opcode(); } } -bool HloInstruction::IsRank2Transpose() const { - return (opcode_ == HloOpcode::kTranspose) && - dimensions_ == std::vector({1, 0}) && - shape_.dimensions_size() == 2 && - std::equal(shape_.dimensions().begin(), shape_.dimensions().end(), - operands_[0]->shape_.dimensions().rbegin()); -} - void HloInstruction::RemoveUser(HloInstruction* user) { auto set_it = user_set_.find(user); CHECK(set_it != user_set_.end()); @@ -2012,28 +1588,13 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { return Status::OK(); } -void HloInstruction::DetachFromOperands() { - VLOG(3) << "DetachFromOperands:\n " << ToString(); - CHECK_EQ(0, user_count()); - // An instruction may be repeated as an operand. To avoid calling RemoveUser - // twice on the same operand, keep a set of already detached operands. - std::set detached_operands; - for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { - HloInstruction* operand = operands_[operand_num]; - if (!ContainsKey(detached_operands, operand)) { - operand->RemoveUser(this); - detached_operands.insert(operand); - } - operands_[operand_num] = nullptr; - } -} - HloComputation* HloInstruction::to_apply() const { switch (opcode_) { case HloOpcode::kCall: case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -2051,6 +1612,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -2060,16 +1622,6 @@ void HloInstruction::set_to_apply(HloComputation* computation) { } } -const string& HloInstruction::custom_call_target() const { - CHECK_EQ(opcode_, HloOpcode::kCustomCall); - return custom_call_target_; -} - -const string& HloInstruction::outfeed_config() const { - CHECK_EQ(opcode_, HloOpcode::kOutfeed); - return outfeed_config_; -} - HloComputation* HloInstruction::while_condition() const { CHECK_EQ(HloOpcode::kWhile, opcode_); return called_computations_[kConditionComputationIndex]; @@ -2096,32 +1648,6 @@ void HloInstruction::set_while_body(HloComputation* computation) { called_computations_[kBodyComputationIndex] = computation; } -HloComputation* HloInstruction::select() const { - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return called_computations_[kSelectComputationIndex]; -} - -HloComputation* HloInstruction::scatter() const { - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - return called_computations_[kScatterComputationIndex]; -} - -void HloInstruction::set_select(HloComputation* computation) { - // Don't allow changing the computation for fused instructions so we don't - // have to recompute called_instructions for the entire fusion instruction. - CHECK(!IsFused()); - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - called_computations_[kSelectComputationIndex] = computation; -} - -void HloInstruction::set_scatter(HloComputation* computation) { - // Don't allow changing the computation for fused instructions so we don't - // have to recompute called_instructions for the entire fusion instruction. - CHECK(!IsFused()); - CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_); - called_computations_[kScatterComputationIndex] = computation; -} - HloComputation* HloInstruction::true_computation() const { CHECK_EQ(HloOpcode::kConditional, opcode_); return called_computations_[kTrueComputationIndex]; @@ -2169,6 +1695,71 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { return ToStringWithCanonicalNameMap(options, &new_map); } +bool HloInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + switch (opcode_) { + // Unary elementwise operations. + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kCeil: + case HloOpcode::kClz: + case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kTanh: + CHECK_EQ(1, operand_count()); + return true; + + // Binary elementwise operations, the same as in IsElementwiseBinary(). + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: + case HloOpcode::kDivide: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + CHECK_EQ(2, operand_count()); + return true; + + // Ternary elementwise operations. + case HloOpcode::kSelect: + return !ShapeUtil::IsTuple(shape_); + case HloOpcode::kClamp: + return true; + + default: + return false; + } +} + string HloInstruction::ToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -2219,112 +1810,45 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { string operands; - if (opcode() == HloOpcode::kConstant) { - // For constants, show the actual value in place of an empty operand list. - // - // In HloInstruction, sometimes a constant literal is not constructed due - // to its size. Skip the printing in this case. - if (HasLiteral() && ((!ShapeUtil::IsTuple(shape()) && - ShapeUtil::ElementsIn(shape()) <= 10) || - options.print_large_constants())) { - // Literal::ToString emits multidimensional arrays over multiple - // lines. Compact this into one line by stripping out white space. - string tmp = literal().ToString(); - std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector v = tensorflow::str_util::Split(tmp, ' '); - bool first = true; - // Concatenate elements in "v" with spaces separating them, but ignoring - // empty entries. - for (const auto& s : v) { - if (s.empty()) { - continue; - } - StrAppend(&operands, (first ? "" : " "), s); - first = false; - } - } else { - // Do not show large constants or tuples. - operands = "{...}"; + tensorflow::gtl::ArraySlice slice(operands_); + const int64 kMaxOperandsToShowIfCompact = 4; + if (options.compact_operands() && + slice.size() > kMaxOperandsToShowIfCompact) { + slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); + } + operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + // If operand is already been deleted, put `null` to the string output. + if (operand == nullptr) { + StrAppend(out, "null "); + return; } - } else if (opcode() == HloOpcode::kParameter) { - StrAppend(&operands, parameter_number_); - } else { - tensorflow::gtl::ArraySlice slice(operands_); - const int64 kMaxOperandsToShowIfCompact = 4; - if (options.compact_operands() && - slice.size() > kMaxOperandsToShowIfCompact) { - slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); + std::vector str; + if (options.print_operand_shape()) { + str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); } - operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { - std::vector str; - if (options.print_operand_shape()) { - str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); - } - // In a top-level HloInstruction::ToString() call, the operand name is not - // part of the canonical string. - if (options.canonicalize_instruction_names() && - options.is_in_nested_computation()) { - str.push_back(PrintName( - canonical_name_map->LookupOrInsert(operand->name()), options)); - } else if (!options.compact_operands()) { - str.push_back(PrintName(operand->name(), options)); - } - StrAppend(out, Join(str, " ")); - }); - const int64 remaining = operands_.size() - slice.size(); - if (slice.size() != operands_.size()) { - StrAppend(&operands, ", ...(+", remaining, ")"); + // In a top-level HloInstruction::ToString() call, the operand name is not + // part of the canonical string. + if (options.canonicalize_instruction_names() && + options.is_in_nested_computation()) { + str.push_back(PrintName( + canonical_name_map->LookupOrInsert(operand->name()), options)); + } else if (!options.compact_operands()) { + str.push_back(PrintName(operand->name(), options)); } + StrAppend(out, Join(str, " ")); + }); + const int64 remaining = operands_.size() - slice.size(); + if (slice.size() != operands_.size()) { + StrAppend(&operands, ", ...(+", remaining, ")"); } return operands; } std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { - std::vector extra; - if (opcode() == HloOpcode::kFusion) { - extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); - } - if (CanHaveDimensionsField()) { - extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); - } - if (window_ != nullptr && window_->dimensions_size() != 0) { - extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); - } - if (padding_config_ != nullptr) { - extra.push_back( - StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); - } - if (opcode() == HloOpcode::kSlice) { - std::vector bounds; - bounds.reserve(slice_starts_.size()); - const bool omit_stride = - std::all_of(slice_strides_.begin(), slice_strides_.end(), - [](int64 stride) { return stride == 1; }); - for (int i = 0; i < slice_starts_.size(); ++i) { - string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); - bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i], - stride_str, "]")); - } - extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); - } - if (opcode() == HloOpcode::kDynamicSlice) { - extra.push_back( - StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); - } - if (opcode() == HloOpcode::kBatchNormTraining || - opcode() == HloOpcode::kBatchNormInference || - opcode() == HloOpcode::kBatchNormGrad) { - extra.push_back(StrCat("epsilon=", epsilon())); - extra.push_back(StrCat("feature_index=", feature_index())); - } + std::vector extra = ExtraAttributesToStringImpl(options); - if (convolution_dimension_numbers_ != nullptr) { - extra.push_back(StrCat( - "dim_labels=", - ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); - } if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } @@ -2333,10 +1857,6 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); } - if (opcode() == HloOpcode::kFft) { - extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); - extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); - } if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { @@ -2356,7 +1876,8 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(false_computation()->name(), options))); } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || - opcode() == HloOpcode::kReduce) { + opcode() == HloOpcode::kReduce || + opcode() == HloOpcode::kCrossReplicaSum) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2391,6 +1912,7 @@ std::vector HloInstruction::ExtraAttributesToString( case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2406,14 +1928,7 @@ std::vector HloInstruction::ExtraAttributesToString( break; } } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || - opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { - extra.push_back(StrCat("channel_id=", channel_id_)); - } - if (opcode() == HloOpcode::kGetTupleElement) { - extra.push_back(StrCat("index=", tuple_index())); - } if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } @@ -2426,33 +1941,11 @@ std::vector HloInstruction::ExtraAttributesToString( }), "}")); } - if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) { - extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")); - } - if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) { - extra.push_back( - StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); - } - if (opcode() == HloOpcode::kRng) { - extra.push_back( - StrCat("distribution=", RandomDistributionToString(distribution_))); - } - if (opcode() == HloOpcode::kReducePrecision) { - extra.push_back(StrCat("exponent_bits=", exponent_bits_)); - extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); - } if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), "\", entry=", operand_side_metadata_->ToString(), ", exit=", user_side_metadata_->ToString(), "}")); } - // By contract, we print the custom call target even if - // options.print_subcomputation_mode() == kOff, because the call target is not - // an HloComputation. - if (opcode() == HloOpcode::kCustomCall) { - extra.push_back( - StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); - } return extra; } @@ -2484,31 +1977,12 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; proto.set_backend_config(backend_config_); - if (literal_ != nullptr) { - *proto.mutable_literal() = literal_->ToProto(); - } - proto.set_parameter_number(parameter_number_); - if (opcode() == HloOpcode::kFusion) { - proto.set_fusion_kind(xla::ToString(fusion_kind())); - proto.add_called_computation_ids( - fused_instructions_computation()->unique_id()); - } else { + if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); } } - proto.set_tuple_index(tuple_index_); - for (int64 dimension : dimensions_) { - proto.add_dimensions(dimension); - } - if (window_ != nullptr) { - *proto.mutable_window() = *window_; - } - if (convolution_dimension_numbers_ != nullptr) { - *proto.mutable_convolution_dimension_numbers() = - *convolution_dimension_numbers_; - } if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } @@ -2520,42 +1994,11 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_gather_window_bounds(bound); } } - for (int i = 0; i < slice_starts_.size(); ++i) { - auto* slice_dimension = proto.add_slice_dimensions(); - slice_dimension->set_start(slice_starts_[i]); - slice_dimension->set_limit(slice_limits_[i]); - slice_dimension->set_stride(slice_strides_[i]); - } - proto.set_exponent_bits(exponent_bits_); - proto.set_mantissa_bits(mantissa_bits_); - for (int64 slice_size : dynamic_slice_sizes_) { - proto.add_dynamic_slice_sizes(slice_size); - } - if (padding_config_ != nullptr) { - *proto.mutable_padding_config() = *padding_config_; - } - proto.set_outfeed_config(outfeed_config_); - if (opcode() == HloOpcode::kRng) { - proto.set_distribution(distribution_); - } - proto.set_epsilon(epsilon_); - proto.set_feature_index(feature_index_); - proto.set_channel_id(channel_id_); - proto.set_infeed_config(infeed_config_); - proto.set_custom_call_target(custom_call_target_); - *proto.mutable_outfeed_shape() = outfeed_shape_; - proto.set_fft_type(fft_type_); - for (int64 fft_len : fft_length_) { - proto.add_fft_length(fft_len); - } if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); } - proto.set_channel_name(channel_name_); - proto.set_cost_estimate_ns(cost_estimate_ns_); - return proto; } @@ -2565,35 +2008,6 @@ string HloInstruction::ToCategory() const { return "data formatting"; } - if (opcode() == HloOpcode::kConvolution) { - string category = "convolution"; - if (window_util::HasBaseDilation(window())) { - category += " base-dilated"; - } - if (window_util::HasWindowDilation(window())) { - category += " window-dilated"; - } - return category; - } - - // Give transpose-dot and backwards-conv fusions the categories "dot" and - // "convolution" so they match the categories of proper kDot and kConvolution - // ops. These fusion categories are really just a way of expressing a - // particular kind of dot or conv, so they should have the same category as a - // vanilla dot/conv. - if (opcode() == HloOpcode::kFusion) { - switch (fusion_kind()) { - case FusionKind::kLoop: - return "loop fusion"; - case FusionKind::kInput: - return "input fusion"; - case FusionKind::kOutput: - return "output fusion"; - case FusionKind::kCustom: - return "custom fusion"; - } - } - if (IsElementwise()) { return "non-fusion elementwise"; } @@ -2607,73 +2021,22 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { trace_instruction_ = trace_instruction; } -string HloInstruction::TracingTag() const { - CHECK_EQ(HloOpcode::kTrace, opcode()); - CHECK(literal_ != nullptr); - return literal_->GetR1U8AsString(); -} - bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } -bool HloInstruction::IsFusable() const { - // Instructions which are traced should not be fused. - if (tracing()) { - return false; - } - // Some kinds of instructions don't make sense to fuse. - switch (opcode_) { - case HloOpcode::kDomain: - case HloOpcode::kParameter: - return false; - // Side effecting instrutions cannot be fused. - default: - return !HasSideEffect(); - } -} - -HloComputation* HloInstruction::fused_instructions_computation() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(!called_computations_.empty()); - auto* fused_instructions_computation = called_computations_.front(); - CHECK(fused_instructions_computation->IsFusionComputation()) - << "Computation " << fused_instructions_computation->name() - << " is not a fusion kind"; - return fused_instructions_computation; -} - -HloInstruction* HloInstruction::fused_expression_root() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->root_instruction(); -} - -HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->parameter_instruction( - parameter_number); -} - -const std::vector& HloInstruction::fused_parameters() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->parameter_instructions(); -} - -const tensorflow::gtl::iterator_range>::const_iterator>> -HloInstruction::fused_instructions() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - const HloComputation* subcomp = fused_instructions_computation(); - return subcomp->instructions(); -} - -const tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> -HloInstruction::fused_instructions() { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->instructions(); -} - -int64 HloInstruction::fused_instruction_count() const { - return fused_instructions_computation()->instruction_count(); +bool HloInstruction::IsFusable() const { + // Instructions which are traced should not be fused. + if (tracing()) { + return false; + } + // Some kinds of instructions don't make sense to fuse. + switch (opcode_) { + case HloOpcode::kDomain: + case HloOpcode::kParameter: + return false; + // Side effecting instrutions cannot be fused. + default: + return !HasSideEffect(); + } } HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) @@ -2854,6 +2217,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleGather(this); case HloOpcode::kDomain: return visitor->HandleDomain(this); + case HloOpcode::kGenerateToken: + return visitor->HandleGenerateToken(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -3094,12 +2459,6 @@ Status HloInstruction::AcceptOrdered( return visitor->FinishVisit(this); } -const Shape& HloInstruction::outfeed_shape() const { - DCHECK_EQ(opcode_, HloOpcode::kOutfeed); - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); - return outfeed_shape_; -} - const Shape& HloInstruction::shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); return shape_; @@ -3121,87 +2480,7 @@ bool HloInstruction::IsElementwiseBinary() const { } bool HloInstruction::IsElementwise() const { - switch (opcode_) { - // Nullary elementwise operations. - case HloOpcode::kConstant: - return true; - - // Unary elementwise operations. - case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: - case HloOpcode::kCeil: - case HloOpcode::kClz: - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCopy: - case HloOpcode::kCos: - case HloOpcode::kExp: - case HloOpcode::kExpm1: - case HloOpcode::kFloor: - case HloOpcode::kImag: - case HloOpcode::kIsFinite: - case HloOpcode::kLog: - case HloOpcode::kLog1p: - case HloOpcode::kNot: - case HloOpcode::kNegate: - case HloOpcode::kReal: - case HloOpcode::kReducePrecision: - case HloOpcode::kSign: - case HloOpcode::kSin: - case HloOpcode::kTanh: - CHECK_EQ(1, operand_count()); - return true; - - // Binary elementwise operations, the same as in IsElementwiseBinary(). - case HloOpcode::kAdd: - case HloOpcode::kAtan2: - case HloOpcode::kComplex: - case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNe: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: - case HloOpcode::kShiftLeft: - case HloOpcode::kShiftRightArithmetic: - case HloOpcode::kShiftRightLogical: - CHECK_EQ(2, operand_count()); - return true; - - // Ternary elementwise operations. - case HloOpcode::kSelect: - return !ShapeUtil::IsTuple(shape_); - case HloOpcode::kClamp: - return true; - - // Other operations. - case HloOpcode::kRng: - case HloOpcode::kMap: - return true; - case HloOpcode::kFusion: - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - for (auto* fused : fused_instructions()) { - if (fused->opcode() != HloOpcode::kParameter && - !fused->IsElementwise()) { - return false; - } - } - return true; - - default: - return false; - } + return IsElementwiseImpl(tensorflow::gtl::nullopt); } bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { @@ -3209,54 +2488,8 @@ bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); } -namespace { -bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, - const HloInstruction* operand) { - std::vector operand_indices = instruction->OperandIndices(operand); - return std::all_of( - operand_indices.begin(), operand_indices.end(), - [instruction](int64 operand_index) { - return instruction->IsElementwiseOnOperand(operand_index); - }); -} -} // namespace - bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { - // For all instructions other than kFusion, being elementwise on one of the - // operands is equivalent to being elementwise on all the operands. - if (opcode() != HloOpcode::kFusion) { - return IsElementwise(); - } - - CHECK_EQ(HloOpcode::kFusion, opcode()); - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - - // A loop-fusion is elementwise on an operand if all operations (computed - // using BFS) between the operand and the fused root are elementwise. - std::deque worklist; - std::unordered_set visited; - worklist.push_back(fused_parameter(operand_idx)); - visited.insert(fused_parameter(operand_idx)); - while (!worklist.empty()) { - HloInstruction* operand = worklist.front(); - worklist.pop_front(); - for (HloInstruction* user : operand->users()) { - CHECK_GE(user->unique_id(), 0); - if (ContainsKey(visited, user)) { - continue; - } - if (user->IsElementwise() || - IsInstructionElementwiseOnOperand(user, operand)) { - worklist.push_back(user); - visited.insert(user); - } else { - return false; - } - } - } - return true; + return IsElementwiseImpl(operand_idx); } // A helper class for memoized, recursive computation of HloOpcode::kFusion @@ -3278,8 +2511,10 @@ class HloInstruction::FusionReusesParamElements { static UseKind ComputeInternal( int64 i, const HloInstruction& hlo, tensorflow::gtl::FlatMap* cache) { - if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) { - return UseKind::kUse; + if (auto hlo_param = DynCast(&hlo)) { + if (hlo_param->parameter_number() == i) { + return UseKind::kUse; + } } auto p = cache->emplace(&hlo, UseKind{}); @@ -3588,21 +2823,264 @@ void HloInstruction::set_outer_dimension_partitions( outer_dimension_partitions_ = outer_dimension_partitions; } +// TODO(b/80131774): Remove these temporary methods after transition. +int64 HloInstruction::feature_index() const { + return Cast(this)->feature_index(); +} + +float HloInstruction::epsilon() const { + return Cast(this)->epsilon(); +} + +FftType HloInstruction::fft_type() const { + return Cast(this)->fft_type(); +} + +const std::vector& HloInstruction::fft_length() const { + return Cast(this)->fft_length(); +} + +int64 HloInstruction::channel_id() const { + return Cast(this)->channel_id(); +} + +int64 HloInstruction::concatenate_dimension() const { + return Cast(this)->concatenate_dimension(); +} + +bool HloInstruction::IsRank2Transpose() const { + auto transpose = DynCast(this); + return transpose != nullptr && transpose->IsRank2Transpose(); +} + +int64 HloInstruction::slice_starts(int64 dimension) const { + return Cast(this)->slice_starts(dimension); +} + +const std::vector& HloInstruction::slice_starts() const { + return Cast(this)->slice_starts(); +} + +int64 HloInstruction::slice_limits(int64 dimension) const { + return Cast(this)->slice_limits(dimension); +} + +const std::vector& HloInstruction::slice_limits() const { + return Cast(this)->slice_limits(); +} + +int64 HloInstruction::slice_strides(int64 dimension) const { + return Cast(this)->slice_strides(dimension); +} + +const std::vector& HloInstruction::slice_strides() const { + return Cast(this)->slice_strides(); +} + +bool HloInstruction::IsInPlaceSlice() const { + return Cast(this)->IsInPlaceSlice(); +} + +const Literal& HloInstruction::literal() const { + return Cast(this)->literal(); +} + +bool HloInstruction::IsConstant() const { + return DynCast(this) != nullptr; +} + void HloInstruction::RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index) { - CHECK_EQ(opcode(), HloOpcode::kConstant); - Shape* mutable_array_subshape = - ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); - CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + Cast(this)->RelayoutConstant(new_layout, shape_index); +} + +string HloInstruction::TracingTag() const { + return Cast(this)->TracingTag(); +} + +HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { + return Cast(this)->AddFusionOperand(new_operand); +} + +// Delegates to HloFusionInstruction::MergeFusionInstruction. +void HloInstruction::MergeFusionInstruction( + HloInstruction* instruction_to_merge) { + return Cast(this)->MergeFusionInstruction( + Cast(instruction_to_merge)); +} + +// Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. +void HloInstruction::MergeFusionInstructionIntoMultiOutput( + HloInstruction* instruction_to_merge) { + return Cast(this) + ->MergeFusionInstructionIntoMultiOutput( + Cast(instruction_to_merge)); +} + +HloInstruction* HloInstruction::FuseInstruction( + HloInstruction* instruction_to_fuse) { + return Cast(this)->FuseInstruction(instruction_to_fuse); +} + +HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse) { + return Cast(this)->FuseInstructionIntoMultiOutput( + instruction_to_fuse); +} + +HloComputation* HloInstruction::fused_instructions_computation() const { + return Cast(this)->fused_instructions_computation(); +} + +HloInstruction* HloInstruction::fused_expression_root() const { + return Cast(this)->fused_expression_root(); +} + +const tensorflow::gtl::iterator_range>::const_iterator>> +HloInstruction::fused_instructions() const { + return Cast(this)->fused_instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloInstruction::fused_instructions() { + return Cast(this)->fused_instructions(); +} + +int64 HloInstruction::fused_instruction_count() const { + return Cast(this)->fused_instruction_count(); +} + +HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { + return Cast(this)->fused_parameter(parameter_number); +} + +const std::vector& HloInstruction::fused_parameters() const { + return Cast(this)->fused_parameters(); +} + +const bool HloInstruction::IsMultiOutputFusion() const { + const HloFusionInstruction* fusion = DynCast(this); + return fusion != nullptr && fusion->IsMultiOutputFusion(); +} + +HloInstruction::FusionKind HloInstruction::fusion_kind() const { + return Cast(this)->fusion_kind(); +} + +void HloInstruction::set_fusion_kind(FusionKind kind) { + return Cast(this)->set_fusion_kind(kind); +} + +RandomDistribution HloInstruction::random_distribution() const { + return Cast(this)->random_distribution(); +} + +int64 HloInstruction::parameter_number() const { + return Cast(this)->parameter_number(); +} + +int64 HloInstruction::tuple_index() const { + return Cast(this)->tuple_index(); +} + +int32 HloInstruction::exponent_bits() const { + return Cast(this)->exponent_bits(); +} + +int32 HloInstruction::mantissa_bits() const { + return Cast(this)->mantissa_bits(); +} + +string HloInstruction::infeed_config() const { + return Cast(this)->infeed_config(); +} + +void HloInstruction::set_infeed_config(const string& config) { + return Cast(this)->set_infeed_config(config); +} + +const Shape& HloInstruction::outfeed_shape() const { + return Cast(this)->outfeed_shape(); +} + +const string& HloInstruction::outfeed_config() const { + return Cast(this)->outfeed_config(); +} + +const std::vector& HloInstruction::replica_group_ids() const { + return Cast(this)->replica_group_ids(); +} + +string HloInstruction::cross_replica_sum_barrier() const { + return Cast(this)->cross_replica_sum_barrier(); +} + +void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { + return Cast(this)->set_cross_replica_sum_barrier( + barrier); +} + +tensorflow::gtl::optional HloInstruction::all_reduce_id() const { + return Cast(this)->all_reduce_id(); +} - // Normally array_subshape will always have a layout, but this invariant is - // temporarily broken in LayoutAssignment::AssignLayouts. +const ConvolutionDimensionNumbers& +HloInstruction::convolution_dimension_numbers() const { + if (auto convolution = DynCast(this)) { + return convolution->convolution_dimension_numbers(); + } + if (auto custom_call = DynCast(this)) { + return custom_call->convolution_dimension_numbers(); + } + LOG(FATAL) << "Unimplemented method."; +} - if (!mutable_array_subshape->has_layout() || - !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); - *mutable_array_subshape->mutable_layout() = new_layout; +void HloInstruction::set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + if (auto convolution = DynCast(this)) { + convolution->set_convolution_dimension_numbers(dnums); + } else if (auto custom_call = DynCast(this)) { + custom_call->set_convolution_dimension_numbers(dnums); + } else { + LOG(FATAL) << "Unimplemented method."; } } +HloComputation* HloInstruction::select() const { + return Cast(this)->select(); +} + +HloComputation* HloInstruction::scatter() const { + return Cast(this)->scatter(); +} + +void HloInstruction::set_select(HloComputation* computation) { + return Cast(this)->set_select(computation); +} + +void HloInstruction::set_scatter(HloComputation* computation) { + return Cast(this)->set_scatter(computation); +} + +const string& HloInstruction::custom_call_target() const { + return Cast(this)->custom_call_target(); +} + +const string& HloInstruction::channel_name() const { + return Cast(this)->channel_name(); +} + +const PaddingConfig& HloInstruction::padding_config() const { + return Cast(this)->padding_config(); +} + +int64 HloInstruction::slice_sizes(int64 dimension) const { + return Cast(this)->slice_sizes(dimension); +} + +const std::vector& HloInstruction::dynamic_slice_sizes() const { + return Cast(this)->dynamic_slice_sizes(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 905ea5310dd46d4b7e129f6ae4c5d4d0302bc13b..8f59e67123cadf965c8650f4a82622f5443ecac9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -426,10 +426,27 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits); - // Creates a cross replica sum op. + // Creates a cross replica reduction op. + // + // `reduction_computation`: the reduction function. + // + // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // `all_reduce_id`: for Allreduce nodes from different modules, if they have + // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will + // not be applied cross modules. + // + // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, - tensorflow::gtl::ArraySlice operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id = + tensorflow::gtl::nullopt); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -648,6 +665,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates a token instruction used for joining or creating token types which + // thread through side-effecting operations. + static std::unique_ptr CreateGenerateToken( + tensorflow::gtl::ArraySlice operands); + // Creates an instance of GatherDimensionNumbers. static GatherDimensionNumbers MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, @@ -786,9 +808,6 @@ class HloInstruction { // Returns whether the instruction has a constant operand. bool HasConstantOperand() const; - // Returns whether this instruction does a rank-2 transposition. - bool IsRank2Transpose() const; - // Replaces the use of this instruction in "user" with "new_producer". Note // that there might be multiple uses of this instruction in "user"; all will // be replaced. @@ -805,13 +824,6 @@ class HloInstruction { // root to new_producer. Status ReplaceAllUsesWith(HloInstruction* new_producer); - // Detaches an instruction from its operands. That is, remove the instruction - // from each operand's user set. This should only be called prior to - // deallocating the instruction. - // - // TODO(b/78305363): Make this automatic when deleting an instruction. - void DetachFromOperands(); - // Performs a postorder DFS visit using this node as the root. If // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when // complete. If ignore_control_predecessors is true, instructions only @@ -857,38 +869,6 @@ class HloInstruction { template Status Visit(DfsHloVisitorBase* visitor); - // Returns the literal associated with this instruction. - // - // Note: only constant and parameter opcodes have an associated literal. - const Literal& literal() const; - - // Returns whether there is literal associated with this instruction. - bool HasLiteral() const; - - // Returns the parameter number associated with this instruction. - // - // Note: only parameter opcodes have an associated parameter number. - int64 parameter_number() const { - CHECK_EQ(HloOpcode::kParameter, opcode_); - return parameter_number_; - } - - // Returns the dimension sizes or numbers associated with this instruction. - // - // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, - // and reverse. - const std::vector& dimensions() const; - int64 dimensions(int64 index) const; - - // Accessor for the dimension in which a concatenate HLO should occur. - // Precondition: opcode() == HloOpcode::kConcatenate - int64 concatenate_dimension() const; - - // Returns the tuple index associated with this instruction. - // - // Precondition: opcode() == HloOpcode::kGetTupleElement - int64 tuple_index() const; - // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. @@ -916,18 +896,6 @@ class HloInstruction { HloComputation* to_apply() const; void set_to_apply(HloComputation* to_apply); - // Returns the custom_call_target for CustomCall. - // Precondition: opcode() == HloOpcode::kCustomCall - const string& custom_call_target() const; - - // Returns the config for the Outfeed instruction. - // Precondition: opcode() == HloOpcode::kOutfeed - const string& outfeed_config() const; - - // Returns the shape for the Outfeed instruction. - // Precondition: opcode() == HloOpcode::kOutfeed - const Shape& outfeed_shape() const; - // Gets/sets the while_condition or while_body HloComputation for While. The // setters should only be called by HloModule or HloComputation methods. // @@ -937,15 +905,6 @@ class HloInstruction { void set_while_condition(HloComputation* while_condition); void set_while_body(HloComputation* while_body); - // Gets/sets the select or scatter HloComputation for SelectAndScatter. The - // setters should only be called by HloModule or HloComputation methods. - // - // Precondition: opcode() == HloOpcode::kSelectAndScatter. - HloComputation* select() const; - HloComputation* scatter() const; - void set_select(HloComputation* select); - void set_scatter(HloComputation* scatter); - // Gets/sets the true and false HloComputation for Conditional. The setters // should only be called by HloModule or HloComputation methods. // @@ -983,11 +942,11 @@ class HloInstruction { string ToShortString() const; // Returns a serialized representation of this instruction. - HloInstructionProto ToProto() const; + virtual HloInstructionProto ToProto() const; // Returns a category for the HLO. This could be something like "convolution" // or "elementwise". - string ToCategory() const; + virtual string ToCategory() const; // Returns a logging instruction, if the output of this instruction is logged. // @@ -995,111 +954,14 @@ class HloInstruction { HloInstruction* tracing() const; void set_tracing(HloInstruction* trace_instruction); - // Returns the channel id associated with the instruction. The id is - // shared between each Send/Recv pair and is globally unique to identify each - // channel. - // - // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv - int64 channel_id() const { return channel_id_; } - - // Returns the channel name associated with the instruction. The name is - // used to identify host Send/Recv operations. - // - // Precondition: opcode() == HloOpcode::kHostCompute - string channel_name() const { return channel_name_; } - - // Returns feature_index field associated with the instruction. The index - // represents the index of the feature dimension. - // - // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, - // or kBatchNormGrad. - int64 feature_index() const { return feature_index_; } - - // Returns a epsilon value associated with the instruction. The is a small - // number added to the variance to avoid divide-by-zero error. - // - // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, - // or kBatchNormGrad. - float epsilon() const { return epsilon_; } - - // Returns the infeed configuration string. The infeed configuration includes - // any metadata needed for the backend compiler (e.g., infeed buffer address) - // and is target-dependent. - string infeed_config() const { return infeed_config_; } - void set_infeed_config(const string& config) { infeed_config_ = config; } - - // Returns a tag to be used in tracing. - // - // Precondition: opcode() == HloOpcode::kTrace - string TracingTag() const; - - // Returns whether the instruction is a constant. - bool IsConstant() const; - // Returns true if this instruction is fused, ie contained within a fusion // instruction. bool IsFused() const; - // Returns the computation for this fused instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloComputation* fused_instructions_computation() const; - // Returns true if this instruction can be legally fused into a fusion // instruction. bool IsFusable() const; - // Returns the root instruction of the fused expression contained within this - // fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* fused_expression_root() const; - - // Returns the list of fused instructions inside this fusion instruction. The - // returned type is a range of HloInstruction*s. - // - // Precondition: opcode() == HloOpcode::kFusion - const tensorflow::gtl::iterator_range>::const_iterator>> - fused_instructions() const; - - const tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> - fused_instructions(); - - // Gets the number of instructions inside this fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - int64 fused_instruction_count() const; - - // Returns the fused parameter instruction in this fusion instruction - // corresponding to the given parameter number. - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* fused_parameter(int64 parameter_number) const; - - // Returns the vector of fused parameters inside this fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - const std::vector& fused_parameters() const; - - // Returns true if this instruction is a fusion instruction that generates - // multiple outputs. - const bool IsMultiOutputFusion() const { - return opcode() == HloOpcode::kFusion && - fused_expression_root()->opcode() == HloOpcode::kTuple; - } - - FusionKind fusion_kind() const { - CHECK_EQ(HloOpcode::kFusion, opcode_); - return fusion_kind_; - } - - void set_fusion_kind(FusionKind kind) { - CHECK_EQ(HloOpcode::kFusion, opcode_); - fusion_kind_ = kind; - } - // Returns the sharding applied to this operator. // REQUIRES: has_sharding() is true. const HloSharding& sharding() const { @@ -1124,8 +986,11 @@ class HloInstruction { void set_sharding(const HloSharding& sharding) { sharding_ = MakeUnique(sharding); } + void set_single_sharding(const HloSharding& sharding); // Sets a sharding that assigns the current instruction to device. - void set_device_sharding(int64 device); + void set_device_sharding(int64 device) { + set_single_sharding(HloSharding::AssignDevice(device)); + } // Remove any sharding from this operator. void clear_sharding() { sharding_ = nullptr; } // Return true if this operator has a sharding assigned. @@ -1155,167 +1020,17 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // Adds a new operand the fusion instruction. - HloInstruction* AddFusionOperand(HloInstruction* new_operand); - - // Merges the fused instructions from 'instruction_to_merge' into the - // fused instruction set of 'this', updating operands as necessary. - // - // Precondition: opcode() == HloOpcode::kFusion - // Predondition: 'instruction_to_merge' must be an operand of 'this'. - void MergeFusionInstruction(HloInstruction* instruction_to_merge); - - // Merges the fused instructions from instruction_to_merge into the fused - // instruction set of 'this' and generates multioutput fusion instructions. - // All the users of instruction_to_merge will be redirected to 'this' - // instruction. instruction_to_merge will be removed from its parent - // computation. - // - // Precondition: opcode() == HloOpcode::kFusion - void MergeFusionInstructionIntoMultiOutput( - HloInstruction* instruction_to_merge); - - // Fuses the given instruction in this fusion instruction. instruction_to_fuse - // is cloned and the clone is placed in the fusion - // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather - // than moved to cleanly handle the case where the instruction has a use - // outside the fusion instruction. Moving such an instruction into a fusion - // instruction would violate the single-result invariant of HLO instructions - // and significantly complicate code generation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { - return FuseInstructionInternal(instruction_to_fuse); + // TODO(b/80249101): Remove these methods once HLO scheduling and copy + // insertion are integrated, and we don't need to run a separate pass + // of copy elision anymore. + bool CopyElisionAllowed() const { + CHECK_EQ(HloOpcode::kCopy, opcode_); + return copy_elision_allowed_; } - // Fuses the given instruction in this fusion instruction and generate - // multioutput fusion instruction. A clone of the instruction_to_fuse will - // be part of the output of fusion instructions. The users of - // instruction_to_fuse will be redirected to this fusion instructions. - // instruction_to_fuse will be removed from its parent computation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstructionIntoMultiOutput( - HloInstruction* instruction_to_fuse) { - return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); - } - - // Returns the start index in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_starts(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_starts_[dimension]; - } - const std::vector& slice_starts() const { return slice_starts_; } - - // Returns the (exclusive) limit index in the given dimension for a slice - // node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_limits(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_[dimension]; - } - const std::vector& slice_limits() const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_; - } - - // Returns the stride in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_strides(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_strides_[dimension]; - } - const std::vector& slice_strides() const { return slice_strides_; } - - // Returns the flag that describes whether a slice must be lowered into an - // offset into the original operand. - bool IsInPlaceSlice() const { return is_in_place_slice_; } - - // Sets and returns the flag that describes whether a slice must be lowered - // into an offset into the original operand. - bool SetIsInPlaceSlice(bool value) { - is_in_place_slice_ = value; - return value; - } - - // Returns the size of the slice in the given dimension for a dynamic - // slice node. - // - // Precondition: opcode() == HloOpcode::kDynamicSlice - int64 slice_sizes(int64 dimension) const { - CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); - return dynamic_slice_sizes_[dimension]; - } - const std::vector& dynamic_slice_sizes() const { - CHECK_EQ(HloOpcode::kDynamicSlice, opcode_); - return dynamic_slice_sizes_; - } - - // Returns the number of exponent bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 exponent_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return exponent_bits_; - } - - // Returns the number of mantissa bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 mantissa_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return mantissa_bits_; - } - - // Returns data on the window in a windowed operation such as - // convolution. - const Window& window() const { - CHECK(window_ != nullptr); - return *window_; - } - - // Sets the window data in a windowed operation such as convolution. - void set_window(const Window& window) { - window_ = MakeUnique(window); - } - - // Returns the padding configuration for a pad node. - // - // Precondition: opcode() == HloOpcode::kPad - const PaddingConfig& padding_config() const { - CHECK(padding_config_ != nullptr); - return *padding_config_; - } - - // Returns data on the dimension numbers used for a convolution operation, - // which may be a kConvolution instruction or a kCustomCall that implements a - // convolution. - const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { - CHECK(convolution_dimension_numbers_ != nullptr); - return *convolution_dimension_numbers_; - } - - // Sets the convolution dimension numbers on this instruction. In general you - // shouldn't need to call this; instead, specify the convolution dimension - // numbers when you create the instruction. - void set_convolution_dimension_numbers( - const ConvolutionDimensionNumbers& dnums) { - convolution_dimension_numbers_ = - MakeUnique(dnums); - } - - FftType fft_type() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_type_; - } - - const std::vector& fft_length() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_length_; + void SetCopyElisionAllowed(bool value) { + CHECK_EQ(HloOpcode::kCopy, opcode_); + copy_elision_allowed_ = value; } // Returns data on the dimension numbers used for a dot operation. @@ -1340,11 +1055,6 @@ class HloInstruction { // Returns the dump string of the gather dimension numbers. string GatherDimensionNumbersToString() const; - // Returns the random distribution for this rng node. - // - // Precondition: opcode() == HloOpcode::kRng - RandomDistribution random_distribution() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1355,7 +1065,8 @@ class HloInstruction { // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). @@ -1426,9 +1137,14 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Gets/sets the string identifier for this instruction. + // Gets the string identifier for this instruction. const string& name() const { return name_; } - void set_name(tensorflow::StringPiece name) { name_ = std::string(name); } + + // Sets the string identifier for this instruction. Name will be sanitized to + // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + void SetAndSanitizeName(const string& name) { + name_ = NameUniquer::GetSanitizedName(name); + } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1509,18 +1225,269 @@ class HloInstruction { void set_outer_dimension_partitions( const std::vector& outer_dimension_partitions); - // Change the layout for an Constant Hlo instruction to match new_layout. For - // tuple shaped constants shape_index is the path to the internal array - // subshape whose layout needs to be changed. + // Old methods kept for smooth subclassing transition BEGIN. + // TODO(b/80131774): Remove this code. + + // Delegates to HloBatchNormInstruction::feature_index. + int64 feature_index() const; + + // Delegates to HloBatchNormInstruction::epsilon. + float epsilon() const; + + // Delegates to HloFftInstruction::fft_type. + FftType fft_type() const; + + // Delegates to HloFftInstruction::fft_length. + const std::vector& fft_length() const; + + // Delegates to HloSendRecvInstruction::channel_id. + int64 channel_id() const; + + // Returns the dimension sizes or numbers associated with this instruction. + virtual const std::vector& dimensions() const { + LOG(FATAL) << "Unimplemented method."; + } + virtual int64 dimensions(int64 index) const { + LOG(FATAL) << "Unimplemented method."; + } + + // Delegates to HloConcatenateInstruction::concatenate_dimension. + int64 concatenate_dimension() const; + + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + + // Delegates to HloSliceInstruction::slice_start. + int64 slice_starts(int64 dimension) const; + const std::vector& slice_starts() const; + + // Delegates to HloSliceInstruction::slice_limits. + int64 slice_limits(int64 dimension) const; + const std::vector& slice_limits() const; + + // Delegates to HloSliceInstruction::slice_strides. + int64 slice_strides(int64 dimension) const; + const std::vector& slice_strides() const; + + // Delegates to HloSliceInstruction::IsInPlaceSlice. + bool IsInPlaceSlice() const; + + // Returns the literal associated with this instruction. + const Literal& literal() const; + + // Returns whether the instruction is a constant. + bool IsConstant() const; + + // Delegate to HloConstantInstruction::RelayoutConstant. void RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index = {}); + // Delegates to HloTraceInstruction::TracingTag. + string TracingTag() const; + + // Delegates to HloFusionInstruction::AddFusionOperand. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + + // Delegates to HloFusionInstruction::MergeFusionInstruction. + void MergeFusionInstruction(HloInstruction* instruction_to_merge); + + // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. + void MergeFusionInstructionIntoMultiOutput( + HloInstruction* instruction_to_merge); + + // Delegates to HloFusionInstruction::FuseInstruction. + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse); + + // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput. + HloInstruction* FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse); + + // Delegates to HloFusionInstruction::fused_instruction. + HloComputation* fused_instructions_computation() const; + + // Delegates to HloFusionInstruction::fused_expression_root. + HloInstruction* fused_expression_root() const; + + // Delegates to HloFusionInstruction::fused_instructions. + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Delegates to HloFusionInstruction::fused_instruction_count. + int64 fused_instruction_count() const; + + // Delegates to HloFusionInstruction::fused_parameter. + HloInstruction* fused_parameter(int64 parameter_number) const; + + // Delegates to HloFusionInstruction::fused_parameters. + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs. + const bool IsMultiOutputFusion() const; + + // Delegates to HloFusionInstruction::fusion_kind. + FusionKind fusion_kind() const; + + // Delegates to HloFusionInstruction::set_fusion_kind. + void set_fusion_kind(FusionKind kind); + + // Delegates to HloRngInstruction::random_distribution. + RandomDistribution random_distribution() const; + + // Delegates to HloParameterInstruction::parameter_number. + int64 parameter_number() const; + + // Delegates to HloGetTupleElementInstruction::tuple_index. + int64 tuple_index() const; + + // Delegates to HloReducePrecisionInstruction::exponent_bits. + int32 exponent_bits() const; + + // Delegates to HloReducePrecisionInstruction::mantissa_bits. + int32 mantissa_bits() const; + + // Delegates to HloInfeedInstruction::infeed_config. + string infeed_config() const; + + // Delegates to HloInfeedInstruction::set_infeed_config. + void set_infeed_config(const string& config); + + // Returns the config for the Outfeed instruction. + const string& outfeed_config() const; + + // Returns the shape for the Outfeed instruction. + const Shape& outfeed_shape() const; + + // Delegates to HloAllReduceInstruction::replica_group_ids. + const std::vector& replica_group_ids() const; + + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. + string cross_replica_sum_barrier() const; + void set_cross_replica_sum_barrier(const string& barrier); + + // Delegates to HloAllReduceInstruction::all_reduce_id. + tensorflow::gtl::optional all_reduce_id() const; + + // Returns data on the window in a windowed operation such as + // convolution. + virtual const Window& window() const { + LOG(FATAL) << "Unimplemented method."; + } + + // Sets the window data in a windowed operation such as convolution. + virtual void set_window(const Window& window) { + LOG(FATAL) << "Unimplemented method."; + } + + // Returns data on the dimension numbers used for a convolution operation, + // which may be a kConvolution instruction or a kCustomCall that implements a + // convolution. + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const; + + // Sets the convolution dimension numbers on this instruction. In general you + // shouldn't need to call this; instead, specify the convolution dimension + // numbers when you create the instruction. + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums); + + // Delegates to HloSelectAndScatterInstruction::select. + HloComputation* select() const; + + // Delegates to HloSelectAndScatterInstruction::scatter. + HloComputation* scatter() const; + + // Delegates to HloSelectAndScatterInstruction::set_select. + void set_select(HloComputation* computation); + + // Delegates to HloSelectAndScatterInstruction::set_scatter. + void set_scatter(HloComputation* computation); + + // Delegates to HloCustomCallInstruction::custom_call_target. + const string& custom_call_target() const; + + // Delegates to HloHostComputeInstruction::channel_name. + const string& channel_name() const; + + // Delegates to HloPadInstruction::padding_config. + const PaddingConfig& padding_config() const; + + // Delegates to HloDynamicSliceInstruction::slice_sizes. + int64 slice_sizes(int64 dimension) const; + + // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes. + const std::vector& dynamic_slice_sizes() const; + // Old methods kept for smooth subclassing transition END. + protected: + enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; + // Helper class for computing OperandElementUse for kFusion. + class FusionReusesParamElements; + // Internal constructor for a given opcode/shape, other fields must be filled // by factory methods. HloInstruction(HloOpcode opcode, const Shape& shape); + // Appends operand to the list of operands and adds this instruction as a user + // of the operand. + void AppendOperand(HloInstruction* operand); + + void RemoveOperandAt(int index) { + operands_.erase(operands_.begin() + index); + } + + void AppendComputation(HloComputation* computation) { + called_computations_.push_back(computation); + } + + void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); } + + void set_called_computation(int index, HloComputation* computation) { + called_computations_[index] = computation; + } + // Indices of computations in called_computations_ for instructions which call + // multiple computations. + enum { + // kWhile computations. + kBodyComputationIndex = 0, + kConditionComputationIndex = 1, + + // kSelectAndScatter computations. + kSelectComputationIndex = 0, + kScatterComputationIndex = 1, + + // kConditional computations. + kTrueComputationIndex = 0, + kFalseComputationIndex = 1, + }; + private: + // Implementation for non-common logic of CloneWithNewOperands. + virtual std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + // TODO(b/80131774): This should be pure virtual. + LOG(FATAL) << "Unimplemented method."; + } + + // Implementation for non-common logic of ExtraAttributesToString. + virtual std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {}; + } + + // Implementation for IsElementwise if operand_idx is nullopt and for + // IsElementwiseOnOperand if otherwise. + // + // NOTE: For all instructions other than kFusion, being elementwise on one of + // the operands is equivalent to being elementwise on all the operands. + virtual bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const; // Prints an instruction to a string. // // The canonical string representation needs to name operands and instruction @@ -1531,7 +1498,7 @@ class HloInstruction { CanonicalNameMap* canonical_name_map) const; // Prints an operand to a string. - string OperandsToStringWithCanonicalNameMap( + virtual string OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const; @@ -1539,13 +1506,8 @@ class HloInstruction { // OperandsToStringWithCanonicalNameMap() functions. friend class HloComputation; - enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; - - // Helper class for computing OperandElementUse for kFusion. - class FusionReusesParamElements; - // See comments on Identical(). - bool IdenticalSlowPath( + virtual bool IdenticalSlowPath( const HloInstruction& other, const std::function& eq_computations) const; @@ -1555,48 +1517,12 @@ class HloInstruction { const Shape& shape, HloOpcode opcode, tensorflow::gtl::ArraySlice operands); - // Appends operand to the list of operands and adds this instruction as a user - // of the operand. - void AppendOperand(HloInstruction* operand); - // Adds a user for this instruction. void AddUser(HloInstruction* user); // Removes a user for this instruction. void RemoveUser(HloInstruction* user); - // Fuses the given instruction into this fusion instruction. When add_output - // is false (which is the default), instruction_to_fuse is cloned and the - // clone is placed in the fusion instruction. instruction_to_fuse is - // unchanged. - // - // When add_output is true, a clone of the instruction_to_fuse will be part - // of the output of fusion instructions. The users of instruction_to_fuse - // will be redirected to this fusion instructions. instruction_to_fuse will - // be removed from its parent computation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, - bool add_output = false); - - // Clones the given instruction_to_fuse and insert the clone into this fusion - // instruction. If add_output is true, a clone of instruction_to_fuse will - // be in the output of the this fusion instruction (part of the tuple of the - // fusion root). - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, - bool add_output = false); - - // Clones a fusion instruction with a new shape and operands. - std::unique_ptr CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloCloneContext* context = nullptr) const; - - // Returns true if this instruction can legally have the dimensions field - // set. Used for checking precondition of dimensions field accessors. - bool CanHaveDimensionsField() const; - // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; @@ -1628,62 +1554,17 @@ class HloInstruction { // The computation in which this instruction is contained. HloComputation* parent_ = nullptr; - // Shape of outfeed request. - Shape outfeed_shape_; - // Result shape of this instruction. Shape shape_; - // Literal, only present for kConstant. - std::unique_ptr literal_; - - // Constant index, only present for kGetTupleElement. - int64 tuple_index_ = -1; - - // Dimensions present for some operations that require reshaping or - // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. - std::vector dimensions_; - - // Describes the window in a windowed operation such as convolution. - std::unique_ptr window_; - - // Describes the dimension numbers used for a convolution. - std::unique_ptr convolution_dimension_numbers_; - // Describes the dimension numbers used for a dot. std::unique_ptr dot_dimension_numbers_; std::unique_ptr gather_dimension_numbers_; std::vector gather_window_bounds_; - // Describes FFT type for an FFT instruction. - FftType fft_type_ = FftType::FFT; - - // Indicates the FFT length for an FFT instruction. - std::vector fft_length_; - - // Describes the [begin, end) index range for a slice. - std::vector slice_starts_; - std::vector slice_limits_; - std::vector slice_strides_; - - // Describes whether the slice can be lowered to an offset into the operand. - bool is_in_place_slice_ = false; - - // The bit sizes for a reduce-precision operation. - int32 exponent_bits_ = 0; - int32 mantissa_bits_ = 0; - - // Describes the [start, start + size) range size for a dynamic slice - // ('start' is specified dynamically in the second operand of the operation). - std::vector dynamic_slice_sizes_; - - // The padding configuration that describes the edge padding and interior - // padding of this pad instruction. Only set for pad instructions. - std::unique_ptr padding_config_; - - // The type of the fusion. Used by kFusion only. - FusionKind fusion_kind_; + // Used to tag kCopy instructions that are eligible for copy elision. + bool copy_elision_allowed_ = true; // The sharding, if one exists. std::unique_ptr sharding_; @@ -1692,65 +1573,15 @@ class HloInstruction { std::unique_ptr operand_side_metadata_; std::unique_ptr user_side_metadata_; - // For parameter instructions this field holds the parameter number. - int64 parameter_number_ = 0; - - // Name of a global symbol to call, only present for kCustomCall. - string custom_call_target_; - - // Name to use for host send/recv channels, only present for kHostCompute. - string channel_name_; - - // Estimate of the duration of a host computation in nanoseconds. - int64 cost_estimate_ns_ = 0; - // Computations called by this instruction. std::vector called_computations_; - // Indices of computations in called_computations_ for instructions which call - // multiple computations. - enum { - // kWhile computations. - kBodyComputationIndex = 0, - kConditionComputationIndex = 1, - - // kSelectAndScatter computations. - kSelectComputationIndex = 0, - kScatterComputationIndex = 1, - - // kConditional computations. - kTrueComputationIndex = 0, - kFalseComputationIndex = 1, - }; - - // Outfeed configuration information, only present for kOutfeed. - string outfeed_config_; - // A trace instruction that consumes this instruction. // // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as // an operand. HloInstruction* trace_instruction_ = nullptr; - // The distribution requested for random number generation. - // Only present for kRng. - RandomDistribution distribution_; - - // A small float number added to the variance to avoid divide-by-zero error. - // Only present for kBatchNormTraining. - float epsilon_ = 0.0f; - - // An integer value representing the index of the feature dimension. - // Only present for kBatchNormTraining. - int64 feature_index_ = -1; - - // Represents a unique identifier for each Send/Recv instruction pair. - // Only present for kSend or kRecv. - int64 channel_id_ = -1; - - // The string representation of the infeed configuration. - string infeed_config_; - // The backend-specific configuration for how a backend should compile this // HLO. See the documentation on backend_config(). string backend_config_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 313033ddadce6a49936f8d34d38f33e923dc2e35..8ee24f9d92f61453a19a019c6e9c22ce37be1589 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -342,7 +342,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Builds a parameter and feeds it to the map. HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10, {param0}, add_f32)); module->AddEntryComputation(builder.Build()); @@ -381,7 +381,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { // Builds a parameter and an initial value and feeds them to the reduce. HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( @@ -923,6 +923,40 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2))); } +TEST_F(HloInstructionTest, IdenticalCallInstructions) { + const char* const hlo_string = R"( +HloModule Module + +subcomp1 (x: f32[]) -> f32[] { + x = f32[] parameter(0) + ROOT n = f32[] sine(x) +} + +subcomp2 (x: f32[]) -> f32[] { + x = f32[] parameter(0) + ROOT n = f32[] cosine(x) +} + +ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) { + p = f32[] parameter(0) + t1 = f32[] call(p), to_apply=subcomp1 + t2 = f32[] call(p), to_apply=subcomp1 + t3 = f32[] call(p), to_apply=subcomp2 + ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3) + } +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + auto* root = module->entry_computation()->root_instruction(); + auto* t1 = root->operand(0); + auto* t2 = root->operand(1); + auto* t3 = root->operand(2); + + EXPECT_TRUE(StructuralEqual(*t1, *t2)); + EXPECT_FALSE(StructuralEqual(*t1, *t3)); +} + TEST_F(HloInstructionTest, FunctionVisitor) { // Verify the function visitor HloInstruction::Accept visits all instructions // from a root properly given the following graph: @@ -980,6 +1014,23 @@ TEST_F(HloInstructionTest, FullyElementwise) { } } +TEST_F(HloInstructionTest, MapIsElementwise) { + auto module = CreateNewModule(); + const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); + HloComputation::Builder builder(TestName()); + HloComputation::Builder map_builder("id"); + map_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + auto map_computation = module->AddEmbeddedComputation(map_builder.Build()); + auto x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x")); + auto map = builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {x}, map_computation)); + module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(map->IsElementwise()); +} + TEST_F(HloInstructionTest, PartiallyElementwise) { const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5}); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b4ce715391d1e426b676a19958bbb9c6d1e4ffb --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -0,0 +1,1806 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/window_util.h" + +namespace xla { +namespace { + +using ::tensorflow::str_util::CEscape; +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, + const HloInstruction* operand) { + std::vector operand_indices = instruction->OperandIndices(operand); + return std::all_of( + operand_indices.begin(), operand_indices.end(), + [instruction](int64 operand_index) { + return instruction->IsElementwiseOnOperand(operand_index); + }); +} +} // namespace + +HloBatchNormInstruction::HloBatchNormInstruction( + HloOpcode opcode, const Shape& shape, HloInstruction* operand, + HloInstruction* scale, float epsilon, int64 feature_index) + : HloInstruction(opcode, shape), + epsilon_(epsilon), + feature_index_(feature_index) { + AppendOperand(operand); + AppendOperand(scale); +} + +bool HloBatchNormInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return feature_index() == casted_other.feature_index() && + epsilon() == casted_other.epsilon(); +} + +HloInstructionProto HloBatchNormInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_epsilon(epsilon_); + proto.set_feature_index(feature_index_); + return proto; +} + +std::vector HloBatchNormInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("epsilon=", epsilon()), + StrCat("feature_index=", feature_index())}; +} + +HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand, + scale, epsilon, feature_index) { + AppendOperand(offset); +} + +std::unique_ptr +HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), + feature_index()); +} + +HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand, + scale, epsilon, feature_index) { + AppendOperand(offset); + AppendOperand(mean); + AppendOperand(variance); +} + +std::unique_ptr +HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 5); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); +} + +HloBatchNormGradInstruction::HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output, + float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale, + epsilon, feature_index) { + AppendOperand(mean); + AppendOperand(variance); + AppendOperand(grad_output); +} + +std::unique_ptr +HloBatchNormGradInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 5); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); +} + +HloFftInstruction::HloFftInstruction( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length) + : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) { + fft_length_.assign(fft_length.begin(), fft_length.end()); + AppendOperand(operand); +} + +HloInstructionProto HloFftInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fft_type(fft_type_); + for (int64 fft_len : fft_length_) { + proto.add_fft_length(fft_len); + } + return proto; +} + +std::vector HloFftInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("fft_type=", FftType_Name(fft_type())), + StrCat("fft_length={", Join(fft_length(), ","), "}")}; +} + +bool HloFftInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return fft_type() == casted_other.fft_type() && + fft_length() == casted_other.fft_length(); +} + +std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], fft_type_, + fft_length_); +} + +HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, + const Shape& shape, + int64 channel_id) + : HloInstruction(opcode, shape), channel_id_(channel_id) {} + +HloInstructionProto HloSendRecvInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_channel_id(channel_id_); + return proto; +} + +std::vector HloSendRecvInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("channel_id=", channel_id_)}; +} + +bool HloSendRecvInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +// Send instruction produces a tuple of {aliased operand, U32 context}. +HloSendInstruction::HloSendInstruction(HloInstruction* operand, + int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kSend, + ShapeUtil::MakeTupleShape( + {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}), + channel_id) { + AppendOperand(operand); +} + +std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(new_operands[0], channel_id()); +} + +HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) + : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloSendDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +// Recv instruction produces a tuple of {receive buffer, U32 context}. +HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kRecv, + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), + channel_id) {} + +std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 0); + return MakeUnique( + ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); +} + +HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) + : HloSendRecvInstruction( + HloOpcode::kRecvDone, + ShapeUtil::GetTupleElementShape(operand->shape(), 0), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloRecvDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +HloAllReduceInstruction::HloAllReduceInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id) + : HloInstruction(HloOpcode::kCrossReplicaSum, shape), + replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()), + cross_replica_sum_barrier_(barrier.begin(), barrier.end()), + all_reduce_id_(all_reduce_id) { + // TODO(b/79737069): Remove the CHECK when supported. + CHECK(!all_reduce_id_); + for (auto operand : operands) { + AppendOperand(operand); + } + AppendComputation(reduce_computation); +} + +HloInstructionProto HloAllReduceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 i : replica_group_ids_) { + proto.add_replica_group_ids(i); + } + // Proto3 is so sad. + if (all_reduce_id_) { + proto.set_all_reduce_id(*all_reduce_id_); + } + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; +} + +std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result = { + StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")}; + if (!cross_replica_sum_barrier().empty()) { + result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); + } + if (all_reduce_id_) { + result.push_back(StrCat("all_reduce_id=", *all_reduce_id_)); + } + return result; +} + +bool HloAllReduceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return replica_group_ids() == casted_other.replica_group_ids() && + eq_computations(to_apply(), casted_other.to_apply()) && + cross_replica_sum_barrier() == + casted_other.cross_replica_sum_barrier() && + all_reduce_id() == casted_other.all_reduce_id(); +} + +std::unique_ptr +HloAllReduceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* /*context*/) const { + return MakeUnique( + shape, new_operands, to_apply(), replica_group_ids(), + cross_replica_sum_barrier(), all_reduce_id()); +} + +HloReverseInstruction::HloReverseInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kReverse, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloReverseInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloReverseInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReverseInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloConcatenateInstruction::HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension) + : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloConcatenateInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloConcatenateInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloConcatenateInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloConcatenateInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, + dimensions(0)); +} + +HloReduceInstruction::HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation) + : HloInstruction(HloOpcode::kReduce, shape), + dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { + AppendOperand(arg); + AppendOperand(init_value); + AppendComputation(reduce_computation); +} + +HloInstructionProto HloReduceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReduceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + // Reduction results are determined by the reduction dimension and the + // reduction computation. + return dimensions() == casted_other.dimensions() && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], dimensions(), to_apply()); +} + +HloTransposeInstruction::HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kTranspose, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + CHECK_EQ(shape.dimensions().size(), dimensions.size()); + CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); + CHECK(std::equal(operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(dimensions, shape.dimensions()).begin())) + << "shape: " << ShapeUtil::HumanString(shape) + << ", operand->shape(): " << ShapeUtil::HumanString(shape) + << ", dimensions: {" << Join(dimensions, ", ") << "}"; + AppendOperand(operand); +} + +bool HloTransposeInstruction::IsRank2Transpose() const { + return dimensions() == std::vector({1, 0}) && + shape().dimensions_size() == 2 && + std::equal(shape().dimensions().begin(), shape().dimensions().end(), + operand(0)->shape().dimensions().rbegin()); +} + +HloInstructionProto HloTransposeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloTransposeInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloTransposeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloTransposeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloBroadcastInstruction::HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension) + : HloInstruction(HloOpcode::kBroadcast, shape), + dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloBroadcastInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloBroadcastInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloBroadcastInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloBroadcastInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloMapInstruction::HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands) + : HloInstruction(HloOpcode::kMap, shape) { + CHECK(static_operands.empty()) << "static_operands not yet supported"; + for (auto operand : operands) { + AppendOperand(operand); + } + AppendComputation(map_computation); + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. + dimensions_.resize(ShapeUtil::Rank(shape)); + std::iota(dimensions_.begin(), dimensions_.end(), 0); +} + +HloInstructionProto HloMapInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +bool HloMapInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + if (!dimensions().empty()) { + // Check that the map is executed in elementwise compatible dimensions. + if (dimensions().size() != shape().dimensions_size()) { + return false; + } + for (int i = 0; i < dimensions().size(); ++i) { + if (dimensions()[i] != i) { + return false; + } + } + } + return true; +} + +std::vector HloMapInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloMapInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return eq_computations(to_apply(), other.to_apply()); +} + +std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, to_apply()); +} + +HloSliceInstruction::HloSliceInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) + : HloInstruction(HloOpcode::kSlice, shape), + slice_starts_(start_indices.begin(), start_indices.end()), + slice_limits_(limit_indices.begin(), limit_indices.end()), + slice_strides_(strides.begin(), strides.end()) { + AppendOperand(operand); + // For backward compatibility with old serialized computations: if there are + // no strides, assume all strides are 1. + // TODO(b/63317920): remove this code. + if (slice_strides_.empty()) { + slice_strides_ = std::vector(start_indices.size(), 1LL); + } +} + +HloInstructionProto HloSliceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int i = 0; i < slice_starts_.size(); ++i) { + auto* slice_dimension = proto.add_slice_dimensions(); + slice_dimension->set_start(slice_starts_[i]); + slice_dimension->set_limit(slice_limits_[i]); + slice_dimension->set_stride(slice_strides_[i]); + } + return proto; +} + +std::vector HloSliceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector bounds; + bounds.reserve(slice_starts_.size()); + const bool omit_stride = + std::all_of(slice_strides_.begin(), slice_strides_.end(), + [](int64 stride) { return stride == 1; }); + for (int i = 0; i < slice_starts_.size(); ++i) { + string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); + bounds.push_back( + StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); + } + return {StrCat("slice={", Join(bounds, ", "), "}")}; +} + +bool HloSliceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return slice_starts_ == other_slice.slice_starts_ && + slice_limits_ == other_slice.slice_limits_ && + slice_strides_ == other_slice.slice_strides_; +} + +std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], slice_starts_, + slice_limits_, slice_strides_); +} + +HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) + : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), + literal_(std::move(literal)) {} + +HloConstantInstruction::HloConstantInstruction(const Shape& shape) + : HloInstruction(HloOpcode::kConstant, shape) {} + +HloInstructionProto HloConstantInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + if (literal_ != nullptr) { + *proto.mutable_literal() = literal_->ToProto(); + } + return proto; +} + +bool HloConstantInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + return true; +} + +void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index) { + Shape* mutable_array_subshape = + ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); + CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + + // Normally array_subshape will always have a layout, but this invariant is + // temporarily broken in LayoutAssignment::AssignLayouts. + + if (!mutable_array_subshape->has_layout() || + !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { + literal_ = literal_->Relayout(new_layout, shape_index); + *mutable_array_subshape->mutable_layout() = new_layout; + } +} + +bool HloConstantInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return literal() == other_slice.literal(); +} + +std::unique_ptr +HloConstantInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(literal_->CloneToUnique()); +} + +string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string operands; + // For constants, show the actual value in place of an empty operand list. + if (literal_ != nullptr && + ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || + options.print_large_constants())) { + // Literal::ToString emits multidimensional arrays over multiple + // lines. Compact this into one line by stripping out white space. + string tmp = literal().ToString(); + std::replace(tmp.begin(), tmp.end(), '\n', ' '); + std::vector v = tensorflow::str_util::Split(tmp, ' '); + bool first = true; + // Concatenate elements in "v" with spaces separating them, but ignoring + // empty entries. + for (const auto& s : v) { + if (s.empty()) { + continue; + } + StrAppend(&operands, (first ? "" : " "), s); + first = false; + } + } else { + // Do not show large constants or tuples. + operands = "{...}"; + } + return operands; +} + +HloTraceInstruction::HloTraceInstruction(const string& tag, + HloInstruction* operand) + : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()), + literal_(Literal::CreateR1U8(tag)) { + AppendOperand(operand); + operand->set_tracing(this); +} + +HloInstructionProto HloTraceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_literal() = literal_->ToProto(); + return proto; +} + +bool HloTraceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloTraceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode()); +} + +HloFusionInstruction::HloFusionInstruction(const Shape& shape, + FusionKind fusion_kind, + HloInstruction* fused_root) + : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { + CHECK(fused_root != nullptr); + SetAndSanitizeName("fusion"); + set_parent(fused_root->parent()); + set_metadata(fused_root->metadata()); + CloneAndFuseInternal(fused_root); +} + +HloFusionInstruction::HloFusionInstruction( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation) + : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { + for (auto operand : operands) { + AppendOperand(operand); + } + SetAndSanitizeName("fusion"); + AppendComputation(fusion_computation); + fusion_computation->SetFusionInstruction(this); +} + +string HloFusionInstruction::ToCategory() const { + switch (fusion_kind()) { + case FusionKind::kLoop: + return "loop fusion"; + case FusionKind::kInput: + return "input fusion"; + case FusionKind::kOutput: + return "output fusion"; + case FusionKind::kCustom: + return "custom fusion"; + } +} + +HloInstructionProto HloFusionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fusion_kind(xla::ToString(fusion_kind())); + proto.add_called_computation_ids( + fused_instructions_computation()->unique_id()); + return proto; +} + +bool HloFusionInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + if (fusion_kind() != FusionKind::kLoop) { + return false; + } + + if (!operand_idx.has_value()) { + for (auto* fused : fused_instructions()) { + if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { + return false; + } + } + return true; + } + // A loop-fusion is elementwise on an operand if all operations (computed + // using BFS) between the operand and the fused root are elementwise. + std::deque worklist; + std::unordered_set visited; + worklist.push_back(fused_parameter(operand_idx.value())); + visited.insert(fused_parameter(operand_idx.value())); + while (!worklist.empty()) { + HloInstruction* operand = worklist.front(); + worklist.pop_front(); + for (HloInstruction* user : operand->users()) { + CHECK_GE(user->unique_id(), 0); + if (ContainsKey(visited, user)) { + continue; + } + if (user->IsElementwise() || + IsInstructionElementwiseOnOperand(user, operand)) { + worklist.push_back(user); + visited.insert(user); + } else { + return false; + } + } + } + return true; +} + +HloInstruction* HloFusionInstruction::AddFusionOperand( + HloInstruction* new_operand) { + CHECK_EQ(operand_count(), + fused_instructions_computation()->parameter_instructions().size()); + const int64 param_no = operand_count(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. + string param_name = StrCat(new_operand->name(), ".param_", param_no); + HloInstruction* fused_parameter = + fused_instructions_computation()->AddParameter( + HloInstruction::CreateParameter(param_no, new_operand->shape(), + param_name)); + AppendOperand(new_operand); + return fused_parameter; +} + +void HloFusionInstruction::MergeFusionInstruction( + HloFusionInstruction* instruction_to_merge) { + CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != + operands().end()); + // Clone the instruction from which to merge fused instructions. + std::unique_ptr cloned = instruction_to_merge->Clone(); + HloFusionInstruction* cloned_fusion = + static_cast(cloned.get()); + // Replace uses of fused parameters with the corresponding operand of the + // fusion. Add all non-parameter fused instructions to + // 'unfused_instructions' to be merged into 'this'. This is done in reverse + // post order. + std::vector unfused_instructions; + auto fused_instructions = cloned_fusion->fused_instructions_computation() + ->MakeInstructionPostOrder(); + for (auto fused_it = fused_instructions.rbegin(); + fused_it != fused_instructions.rend(); ++fused_it) { + auto fused_instruction = *fused_it; + if (fused_instruction->opcode() == HloOpcode::kParameter) { + TF_CHECK_OK( + fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand( + fused_instruction->parameter_number()))); + } else { + unfused_instructions.push_back(fused_instruction); + } + } + CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root()); + // Replace instruction_to_merge use of 'this' with unfused_root. + TF_CHECK_OK( + instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); + // Fuse 'unfused_instructions' into 'this'. + for (auto& instruction : unfused_instructions) { + FuseInstruction(instruction); + } + CHECK_EQ(0, cloned_fusion->user_count()); + TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( + cloned_fusion->fused_instructions_computation())); +} + +void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( + HloFusionInstruction* instruction_to_merge) { + // Add all non-parameter fused instructions to 'unfused_instructions' to be + // merged into 'this'. `old_to_new' maps the instructions in the fused node + // to the disaseembled fusion instructions. + // Note that we add the unfused instructions to this->parent_ computation. + // This is necessary because the unique_id needs for an instruction and + // it's only added when inserting to the computation. + tensorflow::gtl::FlatMap old_to_new; + std::vector unfused_instructions; + auto computation_to_merge = + instruction_to_merge->fused_instructions_computation(); + auto post_order = computation_to_merge->MakeInstructionPostOrder(); + for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) { + auto fused_instruction = *rit; + if (fused_instruction->opcode() == HloOpcode::kParameter) { + InsertOrDie(&old_to_new, fused_instruction, + instruction_to_merge->mutable_operand( + fused_instruction->parameter_number())); + continue; + } + + // Here we clone the insertion and call FuseInstructionIntoMultiOutput() + // which clones again. This can be improved. + auto cloned_instruction = + parent()->AddInstruction(fused_instruction->Clone()); + unfused_instructions.push_back(cloned_instruction); + InsertOrDie(&old_to_new, fused_instruction, cloned_instruction); + } + for (auto unfused_instruction : unfused_instructions) { + for (int64 index = 0; index < unfused_instruction->operand_count(); + index++) { + auto new_operand = + FindOrDie(old_to_new, unfused_instruction->mutable_operand(index)); + TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand)); + } + } + + HloInstruction* unfused_root = unfused_instructions.front(); + TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); + + TF_CHECK_OK( + instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge)); + if (GetModule()) { + TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge)); + } + + // Fuse the root instruction and generate multiple outputs. + FuseInstructionIntoMultiOutput(unfused_root); + TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); + // The rest instructions are of normal fusing. + for (int64 i = 1; i < unfused_instructions.size(); i++) { + auto instruction = unfused_instructions[i]; + FuseInstruction(instruction); + TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); + } +} + +HloComputation* HloFusionInstruction::fused_instructions_computation() const { + CHECK(!called_computations().empty()); + auto* fused_instructions_computation = called_computations().front(); + CHECK(fused_instructions_computation->IsFusionComputation()) + << "Computation " << fused_instructions_computation->name() + << " is not a fusion kind"; + return fused_instructions_computation; +} + +HloInstruction* HloFusionInstruction::fused_expression_root() const { + return fused_instructions_computation()->root_instruction(); +} + +HloInstruction* HloFusionInstruction::fused_parameter( + int64 parameter_number) const { + return fused_instructions_computation()->parameter_instruction( + parameter_number); +} + +const std::vector& HloFusionInstruction::fused_parameters() + const { + return fused_instructions_computation()->parameter_instructions(); +} + +const tensorflow::gtl::iterator_range>::const_iterator>> +HloFusionInstruction::fused_instructions() const { + const HloComputation* subcomp = fused_instructions_computation(); + return subcomp->instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloFusionInstruction::fused_instructions() { + return fused_instructions_computation()->instructions(); +} + +int64 HloFusionInstruction::fused_instruction_count() const { + return fused_instructions_computation()->instruction_count(); +} + +HloInstruction* HloFusionInstruction::FuseInstructionInternal( + HloInstruction* instruction_to_fuse, bool add_output) { + // When add_output is false, this fusion instruction must be a user of + // instruction_to_fuse. + if (!add_output) { + CHECK(IsUserOf(instruction_to_fuse)); + } + HloInstruction* fused_instruction = + CloneAndFuseInternal(instruction_to_fuse, add_output); + return fused_instruction; +} + +HloInstruction* HloFusionInstruction::CloneAndFuseInternal( + HloInstruction* instruction_to_fuse, bool add_output) { + CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); + VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); + HloInstruction* clone = nullptr; + if (called_computations().empty()) { + // New fusion instruction. It should not be a multioutput instruction. + CHECK(!add_output); + auto builder = HloComputation::Builder("fused_computation", this); + builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); + AppendComputation( + CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); + clone = fused_expression_root(); + } else { + clone = fused_instructions_computation()->AddInstruction( + instruction_to_fuse->Clone(/*suffix=*/"")); + // When add_output is false, instruction_to_fuse is necessarily an operand + // of the fusion instruction. After fusion this will no longer be the + // case. Remove the operand from the operand list and remove its + // corresponding fused parameter instruction. Renumber parameters as + // necessary to make parameter numbers consistent with their index in the + // fused_parameter_ vector. + bool in_operand_list = std::find(operands().begin(), operands().end(), + instruction_to_fuse) != operands().end(); + CHECK(add_output || in_operand_list); + const std::vector& fused_parameters = + fused_instructions_computation()->parameter_instructions(); + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + if (instruction_to_fuse == operand(operand_num)) { + // replace the fused parameter instruction's uses with the clone. + HloInstruction* fused_parameter = fused_parameters[operand_num]; + TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone)); + + // Remove the corresponding fused parameter and operand from their + // respective vectors. + TF_CHECK_OK( + fused_instructions_computation()->RemoveParameter(operand_num)); + RemoveOperandAt(operand_num); + break; + } + } + // We've cloned instruction_to_fuse into this fusion instruction, so this + // fusion instruction is no longer a use of instruction_to_fuse. + if (in_operand_list) { + DetachFrom(instruction_to_fuse); + // When the instruction_to_fuse does not have other users, we don't need + // to generate a multioutput fusion instruction. + if (instruction_to_fuse->user_count() == 0) { + add_output = false; + } + } + } + + // Reread the parameters in the computation. + const std::vector& fused_parameters = + fused_instructions_computation()->parameter_instructions(); + + // Add each operand of the clone as an operand of the fusion instruction. A + // complication is that some clone operands may already be operands of the + // fusion instruction. + for (int64 operand_num = 0; operand_num < clone->operand_count(); + ++operand_num) { + HloInstruction* operand = clone->mutable_operand(operand_num); + + // See if this operand is already an operand of the fusion node. + CHECK_EQ(operands().size(), fused_parameters.size()); + HloInstruction* fused_param = nullptr; + for (int64 i = 0; i < operands().size(); ++i) { + if (this->operand(i) == operand) { + fused_param = fused_parameters[i]; + break; + } + } + + if (fused_param == nullptr) { + // Clone's operand was not already an operand of the fusion + // instruction. Add it as an operand and add a corresponding fused + // parameter instruction. + fused_param = AddFusionOperand(operand); + } + TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); + } + + if (add_output) { + CHECK_GT(instruction_to_fuse->user_count(), 0); + // If this is already a multioutput fusion instruction, expand the root + // tuple by 1. + HloInstruction* fused_root = fused_expression_root(); + HloInstruction::InstructionVector tuple_elements; + bool newly_created_tuple_instr = false; + if (fused_root->opcode() == HloOpcode::kTuple) { + tuple_elements = fused_root->operands(); + } else { + tuple_elements.push_back(fused_root); + newly_created_tuple_instr = true; + } + if (clone->opcode() == HloOpcode::kTuple) { + for (auto inst : clone->operands()) { + tuple_elements.push_back(inst); + } + } else { + tuple_elements.push_back(clone); + } + HloInstruction* new_root = fused_instructions_computation()->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); + fused_instructions_computation()->set_root_instruction(new_root); + *mutable_shape() = new_root->shape(); + if (fused_root->opcode() == HloOpcode::kTuple) { + TF_CHECK_OK( + fused_instructions_computation()->RemoveInstruction(fused_root)); + } + + // If this is a newly created multioutput instruction, we need to update + // the use of the original fusion instruction. + if (newly_created_tuple_instr) { + HloInstruction* new_instr = parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); + TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); + } + int64 index = tuple_elements.size(); + if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { + index -= instruction_to_fuse->operand_count(); + std::vector to_be_removed; + for (auto old_gte : instruction_to_fuse->users()) { + CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); + int64 old_tuple_index = old_gte->tuple_index(); + HloInstruction* new_gte = + parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + old_gte->shape(), this, index + old_tuple_index)); + TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte)); + to_be_removed.push_back(old_gte); + } + for (auto old_gte : to_be_removed) { + TF_CHECK_OK(parent()->RemoveInstruction(old_gte)); + } + TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); + } else { + HloInstruction* new_gte = + parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + clone->shape(), this, index - 1)); + TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte)); + } + } + + VLOG(2) << "New clone:\n" << clone->ToString(); + return clone; +} + +std::vector HloFusionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("kind=", xla::ToString(fusion_kind()))}; +} + +bool HloFusionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return fusion_kind() == other.fusion_kind() && + eq_computations(fused_instructions_computation(), + other.fused_instructions_computation()); +} + +std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + HloModule* module = context != nullptr ? context->module() : GetModule(); + HloComputation* new_fused_computation = nullptr; + if (context != nullptr) { + new_fused_computation = + context->FindComputation(fused_instructions_computation()); + } + if (new_fused_computation == nullptr) { + new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", context)); + } + return MakeUnique(shape, fusion_kind(), new_operands, + new_fused_computation); +} + +HloRngInstruction::HloRngInstruction( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters) + : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) { + for (HloInstruction* param : parameters) { + AppendOperand(param); + } +} + +HloInstructionProto HloRngInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_distribution(distribution_); + return proto; +} + +std::vector HloRngInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("distribution=", RandomDistributionToString(distribution_))}; +} + +bool HloRngInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + return true; +} + +bool HloRngInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, distribution_, new_operands); +} + +HloParameterInstruction::HloParameterInstruction(int64 parameter_number, + const Shape& shape, + const string& name) + : HloInstruction(HloOpcode::kParameter, shape), + parameter_number_(parameter_number) { + SetAndSanitizeName(name); +} + +HloInstructionProto HloParameterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_parameter_number(parameter_number_); + return proto; +} + +string HloParameterInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + return StrCat(parameter_number_); +} + +bool HloParameterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return parameter_number() == casted_other.parameter_number(); +} + +std::unique_ptr +HloParameterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(parameter_number_, shape, name()); +} + +HloGetTupleElementInstruction::HloGetTupleElementInstruction( + const Shape& shape, HloInstruction* operand, int64 index) + : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) { + CHECK(ShapeUtil::IsTuple(operand->shape())); + AppendOperand(operand); +} + +HloInstructionProto HloGetTupleElementInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_tuple_index(tuple_index_); + return proto; +} + +std::vector HloGetTupleElementInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("index=", tuple_index())}; +} + +bool HloGetTupleElementInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return tuple_index() == casted_other.tuple_index(); +} + +std::unique_ptr +HloGetTupleElementInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + tuple_index()); +} + +HloReducePrecisionInstruction::HloReducePrecisionInstruction( + const Shape& shape, HloInstruction* operand, const int exponent_bits, + const int mantissa_bits) + : HloInstruction(HloOpcode::kReducePrecision, shape), + exponent_bits_(exponent_bits), + mantissa_bits_(mantissa_bits) { + AppendOperand(operand); +} + +HloInstructionProto HloReducePrecisionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_exponent_bits(exponent_bits_); + proto.set_mantissa_bits(mantissa_bits_); + return proto; +} + +std::vector HloReducePrecisionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("exponent_bits=", exponent_bits_), + StrCat("mantissa_bits=", mantissa_bits_)}; +} + +bool HloReducePrecisionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + // A reduce-precision operation is determined by the bit sizes. + return exponent_bits() == casted_other.exponent_bits() && + mantissa_bits() == casted_other.mantissa_bits(); +} + +std::unique_ptr +HloReducePrecisionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + shape, new_operands[0], exponent_bits(), mantissa_bits()); +} + +HloInfeedInstruction::HloInfeedInstruction(const Shape& shape, + const string& config) + : HloInstruction(HloOpcode::kInfeed, shape), infeed_config_(config) {} + +HloInstructionProto HloInfeedInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_infeed_config(infeed_config_); + return proto; +} + +std::vector HloInfeedInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (infeed_config_.empty()) { + return {}; + } + return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")}; +} + +bool HloInfeedInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 0); + return MakeUnique(shape, infeed_config()); +} + +HloOutfeedInstruction::HloOutfeedInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::StringPiece outfeed_config) + : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()), + outfeed_shape_(shape), + outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + CHECK(ShapeUtil::Compatible(operand->shape(), shape)) + << "Outfeed shape " << shape << " must be compatible with operand shape " + << operand->shape(); + AppendOperand(operand); +} + +HloInstructionProto HloOutfeedInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_outfeed_config(outfeed_config()); + *proto.mutable_outfeed_shape() = outfeed_shape(); + return proto; +} + +std::vector HloOutfeedInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + if (outfeed_config_.empty()) { + return {}; + } + return {StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")}; +} + +bool HloOutfeedInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(outfeed_shape(), new_operands[0], + outfeed_config()); +} + +HloConvolutionInstruction::HloConvolutionInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers) + : HloInstruction(HloOpcode::kConvolution, shape), + window_(window), + convolution_dimension_numbers_(dimension_numbers) { + if (window_util::HasBaseDilation(window)) { + SetAndSanitizeName(StrCat(name(), "-base-dilated")); + } + if (window_util::HasWindowDilation(window)) { + SetAndSanitizeName(StrCat(name(), "-window-dilated")); + } + AppendOperand(lhs); + AppendOperand(rhs); +} + +string HloConvolutionInstruction::ToCategory() const { + string category = "convolution"; + if (window_util::HasBaseDilation(window())) { + category += " base-dilated"; + } + if (window_util::HasWindowDilation(window())) { + category += " window-dilated"; + } + return category; +} + +HloInstructionProto HloConvolutionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + *proto.mutable_convolution_dimension_numbers() = + convolution_dimension_numbers_; + return proto; +} + +std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( + convolution_dimension_numbers_))); + return extra; +} + +bool HloConvolutionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return protobuf_util::ProtobufEquals(window(), casted_other.window()) && + protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + casted_other.convolution_dimension_numbers()); +} + +std::unique_ptr +HloConvolutionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(shape, new_operands[0], + new_operands[1], window(), + convolution_dimension_numbers_); +} + +HloReduceWindowInstruction::HloReduceWindowInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* init_value, + const Window& window, HloComputation* reduce_computation) + : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) { + AppendOperand(operand); + AppendOperand(init_value); + AppendComputation(reduce_computation); +} + +HloInstructionProto HloReduceWindowInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + return proto; +} + +std::vector HloReduceWindowInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + return extra; +} + +bool HloReduceWindowInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return eq_computations(to_apply(), casted_other.to_apply()) && + protobuf_util::ProtobufEquals(window(), casted_other.window()); +} + +std::unique_ptr +HloReduceWindowInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], window(), to_apply()); +} + +HloSelectAndScatterInstruction::HloSelectAndScatterInstruction( + const Shape& shape, HloInstruction* operand, HloComputation* select, + const Window& window, HloInstruction* source, HloInstruction* init_value, + HloComputation* scatter) + : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) { + AppendOperand(operand); + AppendOperand(source); + AppendOperand(init_value); + // Select comes before scatter in the vector. + AppendComputation(select); + AppendComputation(scatter); +} + +HloInstructionProto HloSelectAndScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_window() = window_; + return proto; +} + +std::vector HloSelectAndScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_.dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(window()), "}")); + } + return extra; +} + +bool HloSelectAndScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return eq_computations(select(), casted_other.select()) && + eq_computations(scatter(), casted_other.scatter()) && + protobuf_util::ProtobufEquals(window(), casted_other.window()); +} + +std::unique_ptr +HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], select(), window(), new_operands[1], + new_operands[2], scatter()); +} + +HloCustomCallInstruction::HloCustomCallInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) + : HloInstruction(HloOpcode::kCustomCall, shape), + custom_call_target_(custom_call_target.begin(), + custom_call_target.end()) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloCustomCallInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + if (window_ != nullptr) { + *proto.mutable_window() = *window_; + } + if (convolution_dimension_numbers_ != nullptr) { + *proto.mutable_convolution_dimension_numbers() = + *convolution_dimension_numbers_; + } + proto.set_custom_call_target(custom_call_target_); + return proto; +} + +std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector extra; + if (window_ != nullptr && window_->dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); + } + if (convolution_dimension_numbers_ != nullptr) { + extra.push_back(StrCat( + "dim_labels=", + ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); + } + // By contract, we print the custom call target even if + // options.print_subcomputation_mode() == kOff, because the call target is not + // an HloComputation. + extra.push_back( + StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + return extra; +} + +bool HloCustomCallInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + if ((window_ == nullptr) != (casted_other.window_ == nullptr) || + (window_ != nullptr && + !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) { + return false; + } + if ((convolution_dimension_numbers_ == nullptr) != + (casted_other.convolution_dimension_numbers_ == nullptr) || + (convolution_dimension_numbers_ != nullptr && + !protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + casted_other.convolution_dimension_numbers()))) { + return false; + } + return custom_call_target_ == casted_other.custom_call_target_; +} + +std::unique_ptr +HloCustomCallInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + auto cloned = MakeUnique(shape, new_operands, + custom_call_target()); + if (window_ != nullptr) { + cloned->set_window(*window_); + } + if (convolution_dimension_numbers_ != nullptr) { + cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_); + } + return std::move(cloned); +} + +HloHostComputeInstruction::HloHostComputeInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) + : HloInstruction(HloOpcode::kHostCompute, shape), + channel_name_(channel_name.begin(), channel_name.end()), + cost_estimate_ns_(cost_estimate_ns) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloHostComputeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_channel_name(channel_name_); + proto.set_cost_estimate_ns(cost_estimate_ns_); + return proto; +} + +bool HloHostComputeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +std::unique_ptr +HloHostComputeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique( + shape, new_operands, channel_name_, cost_estimate_ns_); +} + +HloPadInstruction::HloPadInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config) + : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) { + AppendOperand(operand); + AppendOperand(padding_value); +} + +HloInstructionProto HloPadInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_padding_config() = padding_config_; + return proto; +} + +std::vector HloPadInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))}; +} + +bool HloPadInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals(padding_config(), + casted_other.padding_config()); +} + +std::unique_ptr HloPadInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique(shape, new_operands[0], new_operands[1], + padding_config_); +} + +HloDynamicSliceInstruction::HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes) + : HloInstruction(HloOpcode::kDynamicSlice, shape), + dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) { + AppendOperand(operand); + AppendOperand(start_indices); +} + +HloInstructionProto HloDynamicSliceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 slice_size : dynamic_slice_sizes_) { + proto.add_dynamic_slice_sizes(slice_size); + } + return proto; +} + +std::vector HloDynamicSliceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return { + StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")}; +} + +bool HloDynamicSliceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return true; +} + +std::unique_ptr +HloDynamicSliceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..1a2e4ae0a587d889f3064e24f9cda61f34517818 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -0,0 +1,1099 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// All HloInstruction subclasses are put in this file. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +class HloBatchNormInstruction : public HloInstruction { + public: + // Returns feature_index field associated with the instruction. The index + // represents the index of the feature dimension. + int64 feature_index() const { return feature_index_; } + + // Returns a epsilon value associated with the instruction. The is a small + // number added to the variance to avoid divide-by-zero error. + float epsilon() const { return epsilon_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, float epsilon, + int64 feature_index); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // A small float number added to the variance to avoid divide-by-zero error. + float epsilon_ = 0.0f; + + // An integer value representing the index of the feature dimension. + int64 feature_index_ = -1; +}; + +class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormTrainingInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, + HloInstruction* offset, + float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloBatchNormGradInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, + HloInstruction* grad_output, float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloFftInstruction : public HloInstruction { + public: + explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, + FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + FftType fft_type() const { return fft_type_; } + + const std::vector& fft_length() const { return fft_length_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes FFT type for an FFT instruction. + FftType fft_type_ = FftType::FFT; + + // Indicates the FFT length for an FFT instruction. + std::vector fft_length_; +}; + +class HloSendRecvInstruction : public HloInstruction { + public: + // Returns the channel id associated with the instruction. The id is + // shared between each Send/Recv pair and is globally unique to identify each + // channel. + int64 channel_id() const { return channel_id_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, + int64 channel_id); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Represents a unique identifier for each Send/Recv instruction pair. + int64 channel_id_; +}; + +class HloSendInstruction : public HloSendRecvInstruction { + public: + explicit HloSendInstruction(HloInstruction* operand, int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloSendDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloSendDoneInstruction(HloSendInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvInstruction(const Shape& shape, int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvDoneInstruction(HloRecvInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloAllReduceInstruction : public HloInstruction { + public: + explicit HloAllReduceInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& all_reduce_id = + tensorflow::gtl::nullopt); + + // Returns the group ids of each replica for CrossReplicaSum op. + const std::vector& replica_group_ids() const { + return replica_group_ids_; + } + + // Returns the barrier config used for the CrossReplicaSum implementation of + // each backend. + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + + tensorflow::gtl::optional all_reduce_id() const { + return all_reduce_id_; + } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The group id of each replica for CrossReplicaSum. + std::vector replica_group_ids_; + + // The string representation of the barrier config used for CrossReplicaSum. + string cross_replica_sum_barrier_; + + // For Allreduce nodes from different modules, if they have the same + // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross modules. + tensorflow::gtl::optional all_reduce_id_; +}; + +class HloReverseInstruction : public HloInstruction { + public: + explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloConcatenateInstruction : public HloInstruction { + public: + explicit HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Accessor for the dimension in which a concatenate HLO should occur. + int64 concatenate_dimension() const { return dimensions(0); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloReduceInstruction : public HloInstruction { + public: + explicit HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloTransposeInstruction : public HloInstruction { + public: + explicit HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloBroadcastInstruction : public HloInstruction { + public: + explicit HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloMapInstruction : public HloInstruction { + public: + explicit HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands = {}); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloSliceInstruction : public HloInstruction { + public: + explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + HloInstructionProto ToProto() const override; + + // Returns the start index in the given dimension for a slice node. + int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; } + const std::vector& slice_starts() const { return slice_starts_; } + + // Returns the (exclusive) limit index in the given dimension for a slice + // node. + int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; } + const std::vector& slice_limits() const { return slice_limits_; } + + // Returns the stride in the given dimension for a slice node. + int64 slice_strides(int64 dimension) const { + return slice_strides_[dimension]; + } + const std::vector& slice_strides() const { return slice_strides_; } + + // Returns the flag that describes whether a slice must be lowered into an + // offset into the original operand. + bool IsInPlaceSlice() const { return is_in_place_slice_; } + + // Sets and returns the flag that describes whether a slice must be lowered + // into an offset into the original operand. + bool SetIsInPlaceSlice(bool value) { + is_in_place_slice_ = value; + return value; + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes the [begin, end) index range for a slice. + std::vector slice_starts_; + std::vector slice_limits_; + std::vector slice_strides_; + + // Describes whether the slice can be lowered to an offset into the operand. + bool is_in_place_slice_ = false; +}; + +class HloConstantInstruction : public HloInstruction { + public: + explicit HloConstantInstruction(std::unique_ptr literal); + // Used when the literal is too large and dropped. + explicit HloConstantInstruction(const Shape& shape); + // Returns the literal associated with this instruction. + const Literal& literal() const { return *literal_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Change the layout for an Constant Hlo instruction to match new_layout. For + // tuple shaped constants shape_index is the path to the internal array + // subshape whose layout needs to be changed. + void RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index = {}); + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +class HloTraceInstruction : public HloInstruction { + public: + explicit HloTraceInstruction(const string& tag, HloInstruction* operand); + // Returns a tag to be used in tracing. + string TracingTag() const { return literal_->GetR1U8AsString(); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +class HloFusionInstruction : public HloInstruction { + public: + explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, + HloInstruction* fused_root); + + explicit HloFusionInstruction( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation); + + string ToCategory() const override; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Adds a new operand the fusion instruction. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + + // Merges the fused instructions from 'instruction_to_merge' into the + // fused instruction set of 'this', updating operands as necessary. + // + // Predondition: 'instruction_to_merge' must be an operand of 'this'. + void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); + + // Merges the fused instructions from instruction_to_merge into the fused + // instruction set of 'this' and generates multioutput fusion instructions. + // All the users of instruction_to_merge will be redirected to 'this' + // instruction. instruction_to_merge will be removed from its parent + // computation. + void MergeFusionInstructionIntoMultiOutput( + HloFusionInstruction* instruction_to_merge); + + // Fuses the given instruction in this fusion instruction. instruction_to_fuse + // is cloned and the clone is placed in the fusion + // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather + // than moved to cleanly handle the case where the instruction has a use + // outside the fusion instruction. Moving such an instruction into a fusion + // instruction would violate the single-result invariant of HLO instructions + // and significantly complicate code generation. + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { + return FuseInstructionInternal(instruction_to_fuse); + } + + // Fuses the given instruction in this fusion instruction and generate + // multioutput fusion instruction. A clone of the instruction_to_fuse will + // be part of the output of fusion instructions. The users of + // instruction_to_fuse will be redirected to this fusion instructions. + // instruction_to_fuse will be removed from its parent computation. + HloInstruction* FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse) { + return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); + } + + // Returns the computation for this fused instruction. + HloComputation* fused_instructions_computation() const; + + // Returns the root instruction of the fused expression contained within this + // fusion instruction. + HloInstruction* fused_expression_root() const; + + // Returns the list of fused instructions inside this fusion instruction. The + // returned type is a range of HloInstruction*s. + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Gets the number of instructions inside this fusion instruction. + int64 fused_instruction_count() const; + + // Returns the fused parameter instruction in this fusion instruction + // corresponding to the given parameter number. + HloInstruction* fused_parameter(int64 parameter_number) const; + + // Returns the vector of fused parameters inside this fusion instruction. + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs. + const bool IsMultiOutputFusion() const { + return fused_expression_root()->opcode() == HloOpcode::kTuple; + } + + FusionKind fusion_kind() const { return fusion_kind_; } + + void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } + + private: + // Fuses the given instruction into this fusion instruction. When add_output + // is false (which is the default), instruction_to_fuse is cloned and the + // clone is placed in the fusion instruction. instruction_to_fuse is + // unchanged. + // + // When add_output is true, a clone of the instruction_to_fuse will be part + // of the output of fusion instructions. The users of instruction_to_fuse + // will be redirected to this fusion instructions. instruction_to_fuse will + // be removed from its parent computation. + HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, + bool add_output = false); + // Clones the given instruction_to_fuse and insert the clone into this fusion + // instruction. If add_output is true, a clone of instruction_to_fuse will + // be in the output of the this fusion instruction (part of the tuple of the + // fusion root). + HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, + bool add_output = false); + + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The type of the fusion. Used by kFusion only. + FusionKind fusion_kind_; +}; + +class HloRngInstruction : public HloInstruction { + public: + explicit HloRngInstruction( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters); + // Returns the random distribution for this rng node. + RandomDistribution random_distribution() const { return distribution_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The distribution requested for random number generation. + RandomDistribution distribution_; +}; + +class HloParameterInstruction : public HloInstruction { + public: + explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, + const string& name); + int64 parameter_number() const { return parameter_number_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 parameter_number_ = 0; +}; + +class HloGetTupleElementInstruction : public HloInstruction { + public: + explicit HloGetTupleElementInstruction(const Shape& shape, + HloInstruction* operand, int64 index); + // Returns the tuple index associated with this instruction. + int64 tuple_index() const { return tuple_index_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 tuple_index_ = -1; +}; + +class HloReducePrecisionInstruction : public HloInstruction { + public: + explicit HloReducePrecisionInstruction(const Shape& shape, + HloInstruction* operand, + const int exponent_bits, + const int mantissa_bits); + // Returns the number of exponent bits for a reduce-precision node. + int32 exponent_bits() const { return exponent_bits_; } + // Returns the number of mantissa bits for a reduce-precision node. + int32 mantissa_bits() const { return mantissa_bits_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The bit sizes for a reduce-precision operation. + int32 exponent_bits_ = 0; + int32 mantissa_bits_ = 0; +}; + +class HloInfeedInstruction : public HloInstruction { + public: + explicit HloInfeedInstruction(const Shape& shape, const string& config); + // Returns the infeed configuration string. The infeed configuration includes + // any metadata needed for the backend compiler (e.g., infeed buffer address) + // and is target-dependent. + string infeed_config() const { return infeed_config_; } + void set_infeed_config(const string& config) { infeed_config_ = config; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The string representation of the infeed configuration. + string infeed_config_; +}; + +class HloOutfeedInstruction : public HloInstruction { + public: + explicit HloOutfeedInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::StringPiece outfeed_config); + // Returns the shape for the Outfeed instruction. + const Shape& outfeed_shape() const { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape())); + return outfeed_shape_; + } + // Returns the config for the Outfeed instruction. + const string& outfeed_config() const { return outfeed_config_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Shape of outfeed request. + Shape outfeed_shape_; + // Outfeed configuration information, only present for kOutfeed. + string outfeed_config_; +}; + +class HloConvolutionInstruction : public HloInstruction { + public: + explicit HloConvolutionInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { + return convolution_dimension_numbers_; + } + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = dnums; + } + string ToCategory() const override; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; + // Describes the dimension numbers used for a convolution. + ConvolutionDimensionNumbers convolution_dimension_numbers_; +}; + +class HloReduceWindowInstruction : public HloInstruction { + public: + explicit HloReduceWindowInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* init_value, + const Window& window, + HloComputation* reduce_computation); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; +}; + +class HloSelectAndScatterInstruction : public HloInstruction { + public: + explicit HloSelectAndScatterInstruction( + const Shape& shape, HloInstruction* operand, HloComputation* select, + const Window& window, HloInstruction* source, HloInstruction* init_value, + HloComputation* scatter); + const Window& window() const override { return window_; } + void set_window(const Window& window) override { window_ = window; } + // Gets/sets the select or scatter HloComputation for SelectAndScatter. The + // setters should only be called by HloModule or HloComputation methods. + HloComputation* select() const { + return called_computations()[kSelectComputationIndex]; + } + + HloComputation* scatter() const { + return called_computations()[kScatterComputationIndex]; + } + + void set_select(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + set_called_computation(kSelectComputationIndex, computation); + } + + void set_scatter(HloComputation* computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + set_called_computation(kScatterComputationIndex, computation); + } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + Window window_; +}; + +class HloCustomCallInstruction : public HloInstruction { + public: + explicit HloCustomCallInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target); + const Window& window() const override { + CHECK(window_ != nullptr); + return *window_; + } + + void set_window(const Window& window) override { + window_ = MakeUnique(window); + } + + const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { + CHECK(convolution_dimension_numbers_ != nullptr); + return *convolution_dimension_numbers_; + } + + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = + MakeUnique(dnums); + } + const string& custom_call_target() const { return custom_call_target_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // Name of a global symbol to call, only present for kCustomCall. + string custom_call_target_; + // Describes the window in a windowed operation such as convolution. + std::unique_ptr window_; + // Describes the dimension numbers used for a convolution. + std::unique_ptr convolution_dimension_numbers_; +}; + +class HloHostComputeInstruction : public HloInstruction { + public: + explicit HloHostComputeInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + // Returns the channel name associated with the instruction. The name is + // used to identify host Send/Recv operations. + const string& channel_name() const { return channel_name_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // Name to use for host send/recv channels. + string channel_name_; + // Estimate of the duration of a host computation in nanoseconds. + int64 cost_estimate_ns_ = 0; +}; + +class HloPadInstruction : public HloInstruction { + public: + explicit HloPadInstruction(const Shape& shape, HloInstruction* operand, + HloInstruction* padding_value, + const PaddingConfig& padding_config); + // Returns the padding configuration for a pad node. + const PaddingConfig& padding_config() const { return padding_config_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The padding configuration that describes the edge padding and interior + // padding of this pad instruction. + PaddingConfig padding_config_; +}; + +class HloDynamicSliceInstruction : public HloInstruction { + public: + explicit HloDynamicSliceInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + // Old methods kept for smooth subclassing transition END. + // Returns the size of the slice in the given dimension for a dynamic + // slice node. + int64 slice_sizes(int64 dimension) const { + return dynamic_slice_sizes_[dimension]; + } + const std::vector& dynamic_slice_sizes() const { + return dynamic_slice_sizes_; + } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes the [start, start + size) range size for a dynamic slice + // ('start' is specified dynamically in the second operand of the operation). + std::vector dynamic_slice_sizes_; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index c570b420c21fed4d7828feb24ee5c7859db94a79..8a31a8e617c1fb82201e07d9a3ff1ab9a618206b 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -187,6 +187,7 @@ HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(Ge); +HLO_MATCHER(GenerateToken); HLO_MATCHER(Gt); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index e63424c2dfb6c7b9e71e4cede896a8f6609fea62..39bc25ba42c2cb6a9f77e2726405311ba13b3edc 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -32,15 +32,6 @@ limitations under the License. namespace xla { -HloModule::HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config) - : name_(NameUniquer::GetSanitizedName(name)), - config_(config), - has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle), - unique_id_(next_unique_module_id_++) {} - HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(NameUniquer::GetSanitizedName(name)), config_(config), @@ -67,7 +58,7 @@ HloComputation* HloModule::AddComputationInternal( // If the module configuration has no entry layout computation set, create a // default one based on the program shape. - if (!config_.has_host_entry_computation_layout()) { + if (!config_.has_entry_computation_layout()) { config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } @@ -234,21 +225,17 @@ HloModuleProto HloModule::ToProto() const { /* static */ StatusOr> HloModule::CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle) { + const HloModuleProto& proto, const HloModuleConfig& module_config) { // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. TF_RET_CHECK(proto.has_program_shape()) << "No program shape found in the proto"; const auto& expected_program_shape = proto.program_shape(); - TF_RET_CHECK( - expected_program_shape.parameters_size() == - module_config.device_entry_computation_layout().parameter_count()); + TF_RET_CHECK(expected_program_shape.parameters_size() == + module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = - module_config.device_entry_computation_layout() - .parameter_layout(i) - .shape(); + module_config.entry_computation_layout().parameter_layout(i).shape(); TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), parameter_shape)) << "HloModuleConfig has different shape for parameter " << i @@ -258,7 +245,7 @@ StatusOr> HloModule::CreateFromProto( << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape); } const Shape& result_shape = - module_config.device_entry_computation_layout().result_layout().shape(); + module_config.entry_computation_layout().result_layout().shape(); TF_RET_CHECK( ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. " @@ -287,8 +274,7 @@ StatusOr> HloModule::CreateFromProto( } TF_RET_CHECK(entry != nullptr); - auto module = MakeUnique(proto.name(), entry_computation_handle, - module_config); + auto module = MakeUnique(proto.name(), module_config); // Sort the computations in the proto id's order. std::sort(computations.begin(), computations.end(), @@ -338,7 +324,7 @@ StatusOr HloModule::CreateModuleConfigFromProto( // The module config is constructed with default layouts regardless of what is // passed in via the ProgramShape. Set the layouts to the appropriate values. ComputationLayout* entry_layout = - module_config.mutable_host_entry_computation_layout(); + module_config.mutable_entry_computation_layout(); for (int64 i = 0; i < entry_layout->parameter_count(); ++i) { TF_RETURN_IF_ERROR( entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -346,9 +332,6 @@ StatusOr HloModule::CreateModuleConfigFromProto( } TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape( program_shape.result())); - *module_config.mutable_device_entry_computation_layout() = - module_config.host_entry_computation_layout(); - return module_config; } @@ -401,7 +384,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( // as a parameter in the new function. arguments.push_back(old_operand); *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter( - parameter_count, old_operand->shape(), "")); + parameter_count, old_operand->shape(), "p")); ++parameter_count; } TF_CHECK_OK( @@ -462,7 +445,7 @@ int64 HloModule::instruction_count() const { return n; } -std::list HloModule::MakeComputationPostOrder() const { +std::vector HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the // module). @@ -480,7 +463,7 @@ std::list HloModule::MakeComputationPostOrder() const { // order. This prevents duplication as an embedded computation may be called // from two different root computations. std::set added_computations; - std::list post_order; + std::vector post_order; for (auto& computation : computations_) { if (nonroot_computations.count(computation.get()) == 0) { for (HloComputation* embedded_computation : @@ -525,8 +508,6 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; auto module = MakeUnique(name_ + "-" + suffix, config_); - module->entry_computation_handle_ = entry_computation_handle_; - module->has_entry_computation_handle_ = has_entry_computation_handle_; HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index c93c74d34a95cfbb3d0d334fb1c1f40a5aad69e9..d2e726a0db63f622cd5092d56b4f746232d04aad 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -57,10 +56,6 @@ namespace xla { // attached to. class HloModule { public: - HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config); - // Constructor without a versioned computation handle. This constructor should // only be used for HloModules used outside of the XLA service (eg // tests). The versioned handle is used by the service in the compilation @@ -110,24 +105,19 @@ class HloModule { return entry_computation_; } - ComputationLayout* mutable_host_entry_computation_layout() { - return config_.mutable_host_entry_computation_layout(); - } - - const ComputationLayout& host_entry_computation_layout() const { - return config_.host_entry_computation_layout(); + // Creates the ComputationLayout which describes the current status of the HLO + // module entry computation. + ComputationLayout compute_computation_layout() const { + return ComputationLayout(entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); } - ComputationLayout* mutable_device_entry_computation_layout() { - return config_.mutable_device_entry_computation_layout(); + ComputationLayout* mutable_entry_computation_layout() { + return config_.mutable_entry_computation_layout(); } - const ComputationLayout& device_entry_computation_layout() const { - return config_.device_entry_computation_layout(); - } - - const VersionedComputationHandle& entry_computation_handle() const { - return entry_computation_handle_; + const ComputationLayout& entry_computation_layout() const { + return config_.entry_computation_layout(); } // Gets the computations in this module. @@ -163,7 +153,7 @@ class HloModule { // Compute and return a post order of all computations in the module. The sort // is defined like so: if computation A has an instruction which calls // computation B, then A will appear after B in the sort. - std::list MakeComputationPostOrder() const; + std::vector MakeComputationPostOrder() const; // Gets the computations in this module which aren't for fusion nodes. // @@ -188,9 +178,7 @@ class HloModule { // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; static StatusOr> CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle = - VersionedComputationHandle()); + const HloModuleProto& proto, const HloModuleConfig& module_config); // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. @@ -264,10 +252,6 @@ class HloModule { mutable std::mt19937_64 rng_{42}; mutable tensorflow::mutex rng_mutex_; - // Versioned handle of the entry computation of the module. - bool has_entry_computation_handle_ = false; - VersionedComputationHandle entry_computation_handle_; - // Unique name generator for computation and instruction names, which are // unique per module. NameUniquer computation_name_uniquer_{/*separator=*/"."}; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index dae5578a3158fecb8219e518841dec1020b2ca98..07a8c798dbee072db3b75d5e99ca0dcabb5fdf6b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -28,16 +28,14 @@ namespace xla { using tensorflow::strings::StrAppend; -HloModuleConfig::HloModuleConfig() {} - -HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape) - : host_entry_computation_layout_(program_shape), - device_entry_computation_layout_(program_shape) {} +HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape, + bool ignore_layouts) + : entry_computation_layout_( + ComputationLayout(program_shape, ignore_layouts)) {} void HloModuleConfig::SetDefaultComputationLayout( const ProgramShape& program_shape) { - host_entry_computation_layout_ = ComputationLayout(program_shape); - device_entry_computation_layout_ = ComputationLayout(program_shape); + entry_computation_layout_ = ComputationLayout(program_shape); } string HloModuleConfig::compilation_cache_key() const { @@ -46,18 +44,11 @@ string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : - host_entry_computation_layout_->parameter_layouts()) { + entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", - host_entry_computation_layout_->result_shape().SerializeAsString()); - for (const ShapeLayout& param_layout : - device_entry_computation_layout_->parameter_layouts()) { - params.push_back(param_layout.shape().DebugString()); - } - StrAppend( - &key, tensorflow::str_util::Join(params, ", "), ") => ", - device_entry_computation_layout_->result_shape().SerializeAsString()); + entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. static std::atomic counter{0}; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index cdb0b29a2399b387bc617262032e9083ba079625..074e9c90705d432b8344aebaf3c15aeb41a59fa3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -37,48 +37,34 @@ class HloModuleConfig { // ComputationLayout. The default ctor creates it without -- in this case // accessing entry_computation_layout will CHECK-fail. The ctor accepting a // ProgramShape creates a computation layout using this shape. - HloModuleConfig(); - explicit HloModuleConfig(const ProgramShape& program_shape); + // The layouts in the ProgramShape will be reset to default unless + // ignore_layouts is set to false. + HloModuleConfig() = default; - // Checks if this config has an entry computation layout already. - bool has_host_entry_computation_layout() const { - return host_entry_computation_layout_.has_value(); - } + explicit HloModuleConfig(const ProgramShape& program_shape, + bool ignore_layouts = true); - bool has_device_entry_computation_layout() const { - return device_entry_computation_layout_.has_value(); + // Checks if this config has an entry computation layout already. + bool has_entry_computation_layout() const { + return entry_computation_layout_.has_value(); } // Sets the entry computation layout for this config. If the entry computation // layout already exists, it is silently replaced. void SetDefaultComputationLayout(const ProgramShape& program_shape); - // Returns a constant reference to the on-host layout of the entry - // computation. Assumes the layout was set. - const ComputationLayout& host_entry_computation_layout() const { - CHECK(host_entry_computation_layout_.has_value()); - return *host_entry_computation_layout_; - } - - // Returns a mutable pointer to the layout of the on-host entry computation. + // Returns a constant reference to the layout of the entry computation. // Assumes the layout was set. - ComputationLayout* mutable_host_entry_computation_layout() { - CHECK(host_entry_computation_layout_.has_value()); - return &(*host_entry_computation_layout_); - } - - // Returns a constant reference to the on-device layout of the entry - // computation. Assumes the layout was set. - const ComputationLayout& device_entry_computation_layout() const { - CHECK(device_entry_computation_layout_.has_value()); - return *device_entry_computation_layout_; + const ComputationLayout& entry_computation_layout() const { + CHECK(entry_computation_layout_.has_value()); + return *entry_computation_layout_; } - // Returns a mutable pointer to the layout of the on-device entry computation. + // Returns a mutable pointer to the layout of the entry computation. // Assumes the layout was set. - ComputationLayout* mutable_device_entry_computation_layout() { - CHECK(device_entry_computation_layout_.has_value()); - return &(*device_entry_computation_layout_); + ComputationLayout* mutable_entry_computation_layout() { + CHECK(entry_computation_layout_.has_value()); + return &(*entry_computation_layout_); } // Returns whether to enable HLO-level profiling. @@ -127,8 +113,7 @@ class HloModuleConfig { private: // If you add new members, be sure to update compilation_cache_key. - tensorflow::gtl::optional host_entry_computation_layout_; - tensorflow::gtl::optional device_entry_computation_layout_; + tensorflow::gtl::optional entry_computation_layout_; // Whether this is a 'host module'. bool is_host_module_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index f6fa45a6b786492dbe59fea1376b932d06b9354a..bf33640db16638803f4f8e6c66f35d6bb6e2c9fe 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -113,6 +113,9 @@ Status HloModuleGroupMetadata::Build() { } } TF_RETURN_IF_ERROR(VerifyCompanionSets()); + if (VLOG_IS_ON(4)) { + DumpCollectedStats(); + } return Status::OK(); } @@ -124,9 +127,14 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const { for (HloInstruction* instruction : *companions) { // Go through all the communicating instructions (send, recv) of the given // companion, and record their device. + auto it = tracked_instructions_comms_.find(instruction); + if (it == tracked_instructions_comms_.end()) { + // Companions can be added even if they have no communicating + // instructions, if they are parent of companions. + continue; + } std::unordered_set comm_devices; - for (HloInstruction* comm_instruction : - tracked_instructions_comms_.at(instruction)) { + for (HloInstruction* comm_instruction : it->second) { auto device = GetInstructionDevice(*comm_instruction); TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString() << " does not have a device"; @@ -315,6 +323,7 @@ Status HloModuleGroupMetadata::RecordInstructions() { TF_RETURN_IF_ERROR(computation->Accept(visitor)); } } + VLOG(2) << "Created " << channels_.size() << " channels"; return Status::OK(); } @@ -445,4 +454,36 @@ Status HloModuleGroupMetadata::CheckCommunicatingInstruction( return FailedPrecondition("channel is used in disallowed computation"); } +void HloModuleGroupMetadata::DumpCollectedStats() const { + std::map, int64> communication_histogram; + for (auto& channel : channels_) { + auto from_device = GetInstructionDevice(*channel.send); + auto to_device = GetInstructionDevice(*channel.recv); + LOG(INFO) << "Channel " << channel.id << ": from_device=" << *from_device + << " to_device=" << *to_device << " send=" << channel.send->name() + << " send_done=" << channel.send_done->name() + << " recv=" << channel.recv->name() + << " recv_done=" << channel.recv_done->name(); + communication_histogram[std::pair(*from_device, + *to_device)] += 1; + } + for (auto& fromto_count : communication_histogram) { + LOG(INFO) << "From " << fromto_count.first.first << " to " + << fromto_count.first.second << ": " << fromto_count.second; + } + for (auto& companion_set : companion_sets_) { + LOG(INFO) << "Companion set:"; + for (HloInstruction* instruction : *companion_set) { + LOG(INFO) << " " << instruction->name(); + } + } + for (auto& instruction_comm : tracked_instructions_comms_) { + LOG(INFO) << "Communicating instruction " << instruction_comm.first->name(); + for (HloInstruction* instruction : instruction_comm.second) { + auto device = GetInstructionDevice(*instruction); + LOG(INFO) << " " << instruction->name() << " on device " << *device; + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index f68d4028dc33f1670df62b1fc4432fb1968bd255..ffde3a332dfc141ca928a44cfdf4686900e9f47b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -230,6 +230,9 @@ class HloModuleGroupMetadata { return it != tracked_instructions_.end() ? &it->second : nullptr; } + // Dump all the collected module group statistics to the logs. + void DumpCollectedStats() const; + // List of all companion instructions sets in the module. std::vector>> companion_sets_; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 5a0d1e264eb5095ff53721416ebcf4842a063f97..21a9b7291acc9e0066a9061facd13ab5acbf0bac 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -277,7 +277,7 @@ Status HloModuleGroupUtil::VerifyComputations( StatusOr> HloModuleGroupUtil::ComputeReachability( tensorflow::gtl::ArraySlice computations) { - std::list post_order; + std::vector post_order; auto visit_function = [&](HloInstruction* instruction, const std::vector& instruction_group) { diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 1fe06ee0c0d14255b8358fb998bfd8d0b029506f..a35546f5f41b149d119ee141fd734da8bfd055b2 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -81,6 +81,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kHostCompute, "host-compute") \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index cd2ce5c69f030c65b889d67e082a3677b8739ddb..774345124b4ad62e35d9423a23f1dbaa28e44d80 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kConcatenate: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index dcd4725fe78e8b9b5d14437e964cb5aaf1664117..6c1e015f77a62c3e3ff7ffa5ce9dea735f46e10a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -232,6 +232,11 @@ bool HloOrdering::UseIsBeforeValueDefinition( << " and def is in FALSE computation"; return true; } + if (value.defining_instruction() == use.instruction) { + VLOG(4) << " use is conditional " << use << " and def is " + << value.ToShortString(); + return true; + } } VLOG(4) << " use is not before value"; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index ee526d8dd7f7e81b3a846741d3e452935f486bd2..985f3fa64d8767b0c0063ee900f7d11c3b7f6d4a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -183,6 +183,10 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: + // TODO(dimvar): HloModuleSequence is not a good name because it sounds like + // a sequence of modules, instead of a map of schedules for all computations + // in a module. We should change it at some point. + // // A sequence of instructions for each computation in the module. using HloModuleSequence = tensorflow::gtl::FlatMapnum_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - TF_CHECK_OK(module_->mutable_host_entry_computation_layout() - ->mutable_parameter_layout(p) - ->CopyLayoutFromShape(param_shape)); - TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + TF_CHECK_OK(module_->mutable_entry_computation_layout() ->mutable_parameter_layout(p) ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - TF_CHECK_OK(module_->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->CopyLayoutFromShape(result_shape)); - TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + TF_CHECK_OK(module_->mutable_entry_computation_layout() ->mutable_result_layout() ->CopyLayoutFromShape(result_shape)); } - return true; } @@ -587,11 +580,31 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional to_apply; + optional> replica_group_ids; + optional barrier; + optional all_reduce_id; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + attrs["replica_group_ids"] = { + /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; + attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; + attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, + &all_reduce_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands)); + if (replica_group_ids) { + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, *replica_group_ids, + barrier ? *barrier : "", all_reduce_id)); + } else { + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, {}, barrier ? *barrier : "", + all_reduce_id)); + } break; } case HloOpcode::kReshape: { @@ -603,6 +616,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateReshape(shape, operands[0])); break; } + case HloOpcode::kGenerateToken: { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateGenerateToken(operands)); + break; + } case HloOpcode::kTuple: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; @@ -774,6 +795,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; + optional> dimensions; + attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, + &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1134,7 +1158,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloOpcodeString(opcode))); } - instruction->set_name(name); + instruction->SetAndSanitizeName(name); + if (instruction->name() != name) { + return Error(name_loc, + StrCat("illegal instruction name: ", name, + "; suggest renaming to: ", instruction->name())); + } // Add shared attributes like metadata to the instruction, if they were seen. if (sharding) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 84a981675f6c2fc5e0e1c2f103698703ca79a716..d481e07f60a0747ae5bd6217aaeeb25d6fe733e1 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -765,7 +765,7 @@ add_F32.v3 { ENTRY MapBinaryAdder.v3 { param0 = f32[4]{0} parameter(0) param1 = f32[4]{0} parameter(1) - ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3 + ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3 } )" @@ -900,6 +900,42 @@ ENTRY Gather { )" }, +// cross-replica-sum +{ +"CrossReplicaSum", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add +} + +)" +}, +// cross-replica-sum with subgroups +{ +"CrossReplicaSumWithSubgroups", +R"(HloModule CRS_Subgroups + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CrossReplicaSumWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add +} + +)" +} }); // clang-format on } @@ -1266,7 +1302,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); - auto program_layout = module.ValueOrDie()->host_entry_computation_layout(); + auto program_layout = module.ValueOrDie()->entry_computation_layout(); ASSERT_EQ(program_layout.parameter_count(), 1); auto param_layout = program_layout.parameter_layout(0).layout(); auto result_layout = program_layout.result_layout().layout(); diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index d45038f1f4a2e4aa19234eec93fdc9a068a902e1..2418c19f3de7b036d7ef52d3a6db11de6316203b 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -61,7 +61,7 @@ bool AllOperandsAreConstants(const HloInstruction& instruction) { } HloInstruction* GetMatchingOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction) { for (HloInstruction* op : instruction->operands()) { if (matcher(op)) { @@ -72,7 +72,7 @@ HloInstruction* GetMatchingOperand( } bool MatchBinaryInstructionOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction, HloInstruction** matching_operand, HloInstruction** other_operand) { CHECK_EQ(instruction->operand_count(), 2); diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h index c79347bbf9d6146943b7b787f713369cb37fadee..c0826a6aee1f693484207a86ec258c6604d92318 100644 --- a/tensorflow/compiler/xla/service/hlo_query.h +++ b/tensorflow/compiler/xla/service/hlo_query.h @@ -45,7 +45,7 @@ bool IsScalarConstant(const HloInstruction* instruction); // multiple matching operands, then the first matching operand is returned. If // there are no matching operands then nullptr is returned. HloInstruction* GetMatchingOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction); // Returns whether a binary instruction has a matching operand. Sets @@ -53,7 +53,7 @@ HloInstruction* GetMatchingOperand( // other_operand. Note: in the case where both operands match, the first operand // of the instruction is returned. bool MatchBinaryInstructionOperand( - std::function matcher, + const std::function& matcher, HloInstruction* instruction, HloInstruction** matching_operand, HloInstruction** other_operand); diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 4738e46f8aeb96a4c25d04b3246bd21f644fe3ea..01b088a957554821e65db7bf9cedf334db49728f 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { HloReachabilityMap::HloReachabilityMap( - const std::list& instructions) + tensorflow::gtl::ArraySlice instructions) : size_(instructions.size()) { bit_vectors_.reserve(size_); for (const HloInstruction* hlo : instructions) { diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 69bb2b3cee6dafe058c45b4e74e93401bea2cfc9..48215d32a8284919cce6beb1663e6a723eefc1c4 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -41,7 +41,8 @@ class HloReachabilityMap { public: // Sets up a graph with no edges and where the nodes correspond to the given // instructions. - explicit HloReachabilityMap(const std::list& instructions); + explicit HloReachabilityMap( + tensorflow::gtl::ArraySlice instructions); // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 39b85de0f12024f5e20ddd37618987c6d06bc307..62c07d7fac93618a83b3b6111aec1e93309a0761 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -71,6 +72,20 @@ bool IsRematerializable(const HloInstruction* instruction) { } } +// Checks whether an instruction can be rematerialized, by looking up the +// cache before, and eventually calling the IsRematerializable() API. +bool CanBeRematerialized( + const HloInstruction* instruction, + tensorflow::gtl::FlatMap* remat_able) { + auto it = remat_able->find(instruction); + if (it != remat_able->end()) { + return it->second; + } + bool rematerializable = IsRematerializable(instruction); + (*remat_able)[instruction] = rematerializable; + return rematerializable; +} + // Type holding a unique identifier for each Buffer object. using BufferId = int64; using BufferIdList = tensorflow::gtl::InlinedVector; @@ -843,9 +858,10 @@ int64 RematerializationCost(const HloInstruction* instruction, // candidate which reduce memory use at the program point of the current // instruction as indicated by memory_tracker. nullptr is returned if no // candidate can be found. -Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, - const InstructionList& instruction_list, - int64 memory_limit_bytes) { +Item* PickRematerializationCandidate( + const MemoryUsageTracker& memory_tracker, + const InstructionList& instruction_list, int64 memory_limit_bytes, + tensorflow::gtl::FlatMap* remat_able) { Item* best_item = nullptr; int64 best_cost = 0; @@ -869,8 +885,7 @@ Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, << " is excluded from rematerialization"; continue; } - - if (!IsRematerializable(candidate)) { + if (!CanBeRematerialized(candidate, remat_able)) { VLOG(5) << "candidate " << candidate->name() << " not viable: is not rematerializable"; continue; @@ -974,6 +989,9 @@ StatusOr HloRematerialization::RematerializeComputation( // blacklist. tensorflow::gtl::FlatSet remat_move_instructions; + // The map from instructions to their rematerializable status. + tensorflow::gtl::FlatMap remat_able; + // The peak memory of the computation at any point in the instruction // sequence. int64 peak_memory = memory_tracker.memory_usage(); @@ -1011,7 +1029,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); Item* best_item = PickRematerializationCandidate( - memory_tracker, instruction_list, memory_limit_bytes); + memory_tracker, instruction_list, memory_limit_bytes, &remat_able); if (best_item == nullptr) { VLOG(3) << "Unable to find rematerialization candidate at program " @@ -1184,7 +1202,8 @@ StatusOr HloRematerialization::RematerializeComputation( StatusOr HloRematerialization::Run( HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes, RematerializationSizes* sizes) { + int64 memory_limit_bytes, RematerializationSizes* sizes, + bool run_copy_elision) { // The sequence is constructed entirely by this method. TF_RET_CHECK(sequence->empty()); @@ -1213,12 +1232,21 @@ StatusOr HloRematerialization::Run( XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( + TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( *module, [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, scheduler_algorithm_)); + if (run_copy_elision) { + // We run a separate pass of copy elision here because the sequential + // ordering from the HLO schedule allows for more copies to be eliminated. + // TODO(b/80249101): Instead of a separate copy elision pass, use the + // ordering from the HLO schedule directly for copy insertion. + SequentialHloOrdering ordering(module, *sequence); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module)); + } + // Compute peak memory usage of all computations in the module called in a // sequential context. call_graph_ = CallGraph::Build(module); @@ -1321,9 +1349,10 @@ StatusOr HloRematerialization::Run( int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes) { + RematerializationSizes* sizes, bool run_copy_elision) { HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); + return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, + run_copy_elision); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2ee2dd0571ae8c6604e4ca722351fd48a913bda5..59b4cf5dcc761f70767ce4d7ff0959448f29939a 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -57,6 +57,12 @@ class HloRematerialization { // sizes: Optional outparam that indicates the peak memory usage of the HLO // module before/after rematerialization. // + // run_copy_elision: Enable copy elision. This pass is used to eliminate + // copies that were inserted before HLO scheduling. + // + // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy + // insertion is integrated with HLO scheduling. + // // Returns whether any instructions were rematerialized. If memory use is // already below the given limit then no instructions are rematerialized and // false is returned. @@ -68,7 +74,7 @@ class HloRematerialization { const ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes = nullptr); + RematerializationSizes* sizes, bool run_copy_elision = true); protected: HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, @@ -83,7 +89,8 @@ class HloRematerialization { // contains the memory-minimizing order in which to emit the HLO instructions. StatusOr Run(HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit, RematerializationSizes* sizes); + int64 memory_limit, RematerializationSizes* sizes, + bool run_copy_elision); // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 83de54f3fa56ee660b79d8c366dbc0b52f9fde87..7a46da6efe0df23129d56e16355cf66aceb68ffe 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -40,7 +41,8 @@ class HloRematerializationTest : public HloTestBase { // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: // - // F32[] %param = {...} + // F32[1] %param = {...} + // F32[] %reshape = reshape(F32[], param) // F32[1024] %bcast = broadcast(%param) // F32[1024] %negate = negate(%bcast) // F32[2048] %concat_1 = concat({%negate, %negate}) @@ -57,9 +59,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast)); auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate( @@ -100,9 +104,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto slice_1 = builder.AddInstruction( HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, /*limit_indices=*/{1}, @@ -135,6 +141,15 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } + StatusOr RunHloRematerialization( + int64 memory_limit_bytes, HloModule* module, + SequentialHloOrdering::HloModuleSequence* sequence) { + TF_EXPECT_OK(verifier().Run(module).status()); + return HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, + sequence, /*sizes=*/nullptr, /*run_copy_elision=*/false); + } + // Various shapes used in the canned computations. const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); @@ -158,11 +173,9 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/14 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -188,18 +201,16 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/20 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, + module.get(), &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -225,23 +236,21 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/17 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 8); + EXPECT_EQ(body_computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -264,20 +273,18 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/15 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // Both computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 8); + // Both computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(body_computation->instruction_count(), 9); } // Test rematerialization of a doubly nested computation. All computations @@ -303,24 +310,22 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/middle_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(middle_computation->instruction_count(), 6); - EXPECT_EQ(inner_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(middle_computation->instruction_count(), 7); + EXPECT_EQ(inner_computation->instruction_count(), 8); // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/13 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // All computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(middle_computation->instruction_count(), 7); - EXPECT_EQ(inner_computation->instruction_count(), 8); + // All computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(middle_computation->instruction_count(), 9); + EXPECT_EQ(inner_computation->instruction_count(), 9); } TEST_F(HloRematerializationTest, RngNotRematerialized) { @@ -382,10 +387,9 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, + bool changed, RunHloRematerialization( /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), DefaultMemoryScheduler, &sequence)); + module.get(), &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,11 +480,9 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,11 +575,9 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index e1f9d8efd4974055947438c8a2e15cb77d1b5c75..4f0569f4059481aa19da8c7854fedf0e43182e36 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -98,8 +98,10 @@ StatusOr HloRunner::TransferLiteralToDevice( backend().transfer_manager()->AllocateScopedShapedBuffer( literal.shape(), backend().memory_allocator(), backend().default_device_ordinal())); + TF_ASSIGN_OR_RETURN( + auto stream, backend().BorrowStream(backend().default_stream_executor())); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, buffer)); + stream.get(), literal, buffer)); return std::move(buffer); } @@ -127,8 +129,10 @@ StatusOr> HloRunner::TransferLiteralsToDevice( StatusOr> HloRunner::TransferLiteralFromDevice( const ShapedBuffer& buffer) { - return backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), buffer); + TF_ASSIGN_OR_RETURN( + auto stream, backend().BorrowStream(backend().default_stream_executor())); + return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(), + buffer); } StatusOr> HloRunner::Execute( @@ -237,7 +241,7 @@ StatusOr>> HloRunner::ExecuteReplicated( backend().transfer_manager()->AllocateScopedShapedBuffer( argument->shape(), backend().memory_allocator(), device)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, *argument, argument_buffer)); + streams.back().get(), *argument, argument_buffer)); argument_buffers.push_back(std::move(argument_buffer)); argument_buffer_ptrs[index++] = &argument_buffers.back(); } @@ -307,7 +311,7 @@ StatusOr>> HloRunner::ExecuteReplicated( for (int64 i = 0; i < options.num_replicas; ++i) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, backend().transfer_manager()->TransferLiteralFromDevice( - streams[i]->parent(), results[i])); + streams[i].get(), results[i])); exec_results.push_back(std::move(literal)); } return std::move(exec_results); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 68b2cde83a2eb479d9ba71fc6eab9ac9ab1c8267..c6d3909af6103949daf4b0ab6be9b74724461e30 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -36,29 +36,6 @@ using ::tensorflow::strings::HumanReadableNumBytes; namespace xla { -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { - return 0; - } - - const HloModule* module = module_sequence.begin()->first->parent(); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. We run the heap simulation on the whole module, - // rather than summing each computation, since it gives us a better lower - // bound, by minimizing the liveness of sub-computations. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, - module_sequence, *points_to_analysis, size_function)); - return result.heap_size; -} - namespace { // Class implementing a list scheduler of HLO instructions which produces a @@ -398,7 +375,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -416,30 +393,15 @@ StatusOr> CreateMemoryMinimizingSequence( } // namespace -StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; -} - StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation) { - // This ordering is based on DFS post-order, with a heuristic to decide which - // operand to visit first. The heuristic is based on 'extra_users', which is - // simply users-1 for each instruction. By subtracting 1, we're saying that - // instructions with no users or a single user don't count; instructions with - // lots of fan-out will be visited earlier. + // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; + int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { @@ -448,6 +410,11 @@ StatusOr> DFSMemoryScheduler( total_sizes[hlo] = 0; continue; } + // This ordering is based on DFS post-order, with a heuristic to decide + // which operand to visit first. The heuristic is based on 'extra_users', + // which is simply users-1 for each instruction. By subtracting 1, we're + // saying that instructions with no users or a single user don't count; + // instructions with lots of fan-out will be visited earlier. extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; int64 logical_buffer_size = SumLogicalBufferSizes( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); @@ -463,10 +430,13 @@ StatusOr> DFSMemoryScheduler( // lead to it. But computation is a DAG, so we are double-counting nodes, // which can lead to overflows for large programs. // cumulative_total_size caps the size to prevent overflows. + // Same for total_hlos: it prevents overflows on very large and branchy + // models, where the number of paths is exponential to the number of nodes. // NOTE(dimvar): this is quite ugly and should be changed. It's unclear // why we care about transitive sizes; when scheduling a node, its input // and output buffers should be all that matters, not its "history". total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); + extra_users[hlo] = std::min(extra_users[hlo], total_hlos); } CHECK_EQ(extra_users.size(), computation.instruction_count()); CHECK_EQ(total_sizes.size(), computation.instruction_count()); @@ -533,29 +503,29 @@ StatusOr> DefaultMemoryScheduler( std::vector list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 list_memory, - MinimumMemoryForComputation(computation, list_sequence, - points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(const int64 list_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, list_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 dfs_memory, - MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, - size_function)); + TF_ASSIGN_OR_RETURN(const int64 dfs_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, dfs_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( std::vector post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); - TF_ASSIGN_OR_RETURN( - const int64 post_order_memory, - MinimumMemoryForComputation(computation, post_order_sequence, - points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(const int64 post_order_memory, + HeapSimulator::MinimumMemoryForComputation( + computation, post_order_sequence, points_to_analysis, + size_function, &memory_by_computation)); VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); @@ -576,10 +546,9 @@ StatusOr> DefaultMemoryScheduler( } } -StatusOr -CreateMemoryMinimizingSequence(const HloModule& module, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { +StatusOr ScheduleComputationsInModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); @@ -587,12 +556,13 @@ CreateMemoryMinimizingSequence(const HloModule& module, for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(auto one_computation_sequence, - CreateMemoryMinimizingSequence( + ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = - MinimumMemoryForComputation(*computation, one_computation_sequence, - *points_to_analysis, size_function) + HeapSimulator::MinimumMemoryForComputation( + *computation, one_computation_sequence, *points_to_analysis, + size_function, &memory_by_computation) .ValueOrDie(); sequence[computation] = std::move(one_computation_sequence); } @@ -600,15 +570,15 @@ CreateMemoryMinimizingSequence(const HloModule& module, return sequence; } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); tensorflow::gtl::FlatMap empty_map; - return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function, nullptr, empty_map); + return ScheduleComputationHelper(computation, *points_to_analysis, + size_function, nullptr, empty_map); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 49b927eefd24f4e26df781dd8d2b977bedba2b80..2b33ccc8bfb895286bb3747aab0a16cf25e2cfae 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -28,20 +28,6 @@ limitations under the License. namespace xla { -// Returns the minimum memory required to compute the given module sequence, -// assuming no fragmentation. -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function); - -// Returns the minimum memory required to compute the given computation, -// assuming no fragmentation. -StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); - // A memory scheduler computes an execution sequence for the HLO instructions in // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function @@ -89,14 +75,13 @@ StatusOr> DefaultMemoryScheduler( // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. -StatusOr -CreateMemoryMinimizingSequence(const HloModule& module, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); +StatusOr ScheduleComputationsInModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); -// Overload of above that computes the sequence for a single computation. +// Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index db7ef6f0d4bd96216ea07ccc75a51513822bf2e3..73f22f81f4e9cf597db8b184642acff2fdaaf2b0 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -31,65 +32,6 @@ limitations under the License. namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; - -TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { - auto module = CreateNewModule(); - const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); - - auto cond_builder = HloComputation::Builder("WhileCond"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); - HloInstruction* cond_iter = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); - HloInstruction* cond_data = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); - // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) - HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); - HloComputation* cond_computation = - module->AddEmbeddedComputation(cond_builder.Build()); - - auto body_builder = HloComputation::Builder("WhileBody"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "body_param")); - HloComputation* body_computation = - module->AddEmbeddedComputation(body_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - // Entry params: 8 bytes (4 bytes per param), TOTAL=8 - HloInstruction* iter = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); - HloInstruction* data = builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "param_data")); - // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 - HloInstruction* tuple = - builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); - // While: 8 bytes (4 bytes per element), TOTAL=32 - // Both cond and body use a max of 24 bytes, TOTAL=56 - HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, cond_computation, body_computation, tuple)); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, - MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); -} - class HloSchedulingTest : public HloTestBase {}; TEST_F(HloSchedulingTest, LastUseScheduledFirst) { @@ -124,7 +66,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) { + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. @@ -165,7 +107,7 @@ ENTRY root { }; TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler)); + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.at(module->entry_computation()).size()); @@ -203,7 +145,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { // ROOT %subtract = f32[4]{0} subtract( // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) // } - // %SubcomputationsNotAccounted () -> f32[2,4] { + // %ListAccountsForSubcomputations () -> f32[2,4] { // %constant.3 = f32[2,4]{1,0} constant( // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) // %transpose = f32[2,4]{1,0} transpose( @@ -269,16 +211,16 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence( - *module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }, - ListMemoryScheduler)); + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + auto entry_computation = module->entry_computation(); + EXPECT_EQ(entry_computation->instruction_count(), + sequence.at(entry_computation).size()); SequentialHloOrdering ordering(module.get(), sequence); // This schedule is an example of List's greedy heuristics being suboptimal. // The while_loop is more expensive than transpose, so it would have been @@ -287,6 +229,24 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); + + tensorflow::gtl::FlatMap memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations. The max mem doesn't change + // because the while body isn't live during the peak. + EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); } TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { @@ -318,12 +278,12 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, - [&TUPLE_SIZE](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + ScheduleComputationsInModule(*module, + [&TUPLE_SIZE](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), @@ -368,7 +328,7 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { {tuple, mul, add}, HloInstruction::FusionKind::kLoop); TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence( + ScheduleComputationsInModule( *module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), 2); @@ -384,5 +344,70 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); } +TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + auto entry_computation = module->entry_computation(); + EXPECT_EQ(entry_computation->instruction_count(), + sequence.at(entry_computation).size()); + + tensorflow::gtl::FlatMap memory_by_computation; + memory_by_computation[cond_computation] = 17; + memory_by_computation[body_computation] = 16; + std::unique_ptr points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator doesn't account for subcomputations + EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn) + .ValueOrDie()); + // HeapSimulator accounts for subcomputations + EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, sequence.at(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 58224ef870096a774d5892b9aa12c38f5ff511bd..268b4727bcbed42ba71526f1d5ef5c887e941930 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -39,6 +39,34 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { return HloSharding(tile_shape, assignment); } +HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { + std::vector flattened_list; + flattened_list.reserve(sub_shardings.leaf_count()); + for (const auto& index_to_sharding : sub_shardings.leaves()) { + flattened_list.push_back(index_to_sharding.second); + } + if (flattened_list.empty()) { + // Empty tuple sharding ends up having no leaves, but we want to allow + // empty tuple HLO instruction results to have sharding, so we fetch the + // root ({}) sharding value from the ShapeTree. + // A ShapeTree created with ShapeTree(shape, init) will have + // init as value at its root. + flattened_list.push_back(sub_shardings.element(ShapeIndex({}))); + } + return HloSharding(flattened_list); +} + +HloSharding HloSharding::Tuple( + const Shape& tuple_shape, + tensorflow::gtl::ArraySlice shardings) { + CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + std::vector flattened_list(shardings.begin(), shardings.end()); + CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) + << "Flat list has " << flattened_list.size() << ", required " + << RequiredLeaves(tuple_shape); + return HloSharding(flattened_list); +} + string HloSharding::ToString() const { if (IsTuple()) { std::vector parts; @@ -72,6 +100,29 @@ bool HloSharding::UsesDevice(int64 device) const { std::find(devices.begin(), devices.end(), device) != devices.end(); } +std::map HloSharding::UsedDevices(int64* count) const { + int64 element_count = 1; + std::map device_map; + if (IsTuple()) { + for (auto& tuple_element_sharding : tuple_elements()) { + auto unique_device = tuple_element_sharding.UniqueDevice(); + if (unique_device.ok()) { + device_map[unique_device.ValueOrDie()] += 1; + } + } + element_count = tuple_elements().size(); + } else { + auto unique_device = UniqueDevice(); + if (unique_device.ok()) { + device_map[unique_device.ValueOrDie()] += 1; + } + } + if (count != nullptr) { + *count = element_count; + } + return device_map; +} + std::vector HloSharding::TileIndexForDevice(int64 device) const { CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); @@ -123,24 +174,49 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { return index; } +int64 HloSharding::RequiredLeaves(const Shape& shape) { + // Empty tuples have no leaf nodes as far as ShapeUtil and ShapeTree are + // concerned, but they do have a single tuple_elements_ entry since we want + // to allow empty tuple results to have sharding. + return ShapeUtil::IsEmptyTuple(shape) ? 1 : ShapeUtil::GetLeafCount(shape); +} + +Status HloSharding::CheckLeafCount(const Shape& shape) const { + int64 shape_leaves = RequiredLeaves(shape); + TF_RET_CHECK(shape_leaves == tuple_elements_.size()) + << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves + << " leaf nodes while this sharding has " << tuple_elements_.size(); + return Status::OK(); +} + StatusOr> HloSharding::AsShapeTree( const Shape& shape) const { if (IsTuple()) { ShapeTree result(shape, HloSharding::Replicate()); - int64 num_leaves = result.leaf_count(); - TF_RET_CHECK(num_leaves == tuple_elements_.size()) - << "Shape " << ShapeUtil::HumanString(shape) << " has " << num_leaves - << " leaf nodes while this sharding has " << tuple_elements_.size(); + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); auto it = tuple_elements_.begin(); for (auto& index_to_sharding : result.leaves()) { index_to_sharding.second = *it++; } + if (ShapeUtil::IsEmptyTuple(shape)) { + // Empty tuples have no leaves, but we want to assign them a sharding + // anyway, so we use the root element sharding. + *result.mutable_element(ShapeIndex({})) = *it; + } return std::move(result); } else { return ShapeTree(shape, *this); } } +StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { + if (IsTuple()) { + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); + return *this; + } + return Tuple(ShapeTree(shape, *this)); +} + StatusOr HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { @@ -182,28 +258,12 @@ Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { return tensorflow::errors::InvalidArgument( StrCat("Sharding is tuple-shaped but validation shape is not.")); } - // The easiest way to get the number of elements in a nested tuple is just to - // create a shape tree. We could call GetAsShapeTree, but that will try and - // apply our tuple_shardings_ to the shape tree, and that might cause a crash - // at this point as we haven't validated them. - ShapeTree bool_shape_tree(shape, false); - int64 num_leaves = - std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end()); - if (num_leaves != tuple_elements_.size()) { - return tensorflow::errors::InvalidArgument( - StrCat("Validation tuple shape has ", num_leaves, - " leaf elements, but this sharding contains ", - tuple_elements_.size(), " elements.")); - } + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); // Now we've validated the number of tuple elements, it's safe to request a // shape tree. ShapeTree shape_tree = GetAsShapeTree(shape); for (const auto& index_to_sharding : shape_tree.leaves()) { - if (index_to_sharding.first.empty()) { - // An empty tuple has a ShapeTree with a single leaf at the empty index. - continue; - } Status status = index_to_sharding.second.ValidateNonTuple( ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices); if (!status.ok()) { @@ -389,6 +449,40 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, : sub_shape_tree.element(ShapeIndex({})); } +tensorflow::gtl::optional HloSharding::ExtractSingleSharding() + const { + if (!IsTuple()) { + return *this; + } + for (int64 i = 1; i < tuple_elements_.size(); ++i) { + if (tuple_elements_[0] != tuple_elements_[i]) { + return tensorflow::gtl::optional(); + } + } + return tuple_elements_.front(); +} + +size_t HloSharding::Hash() const { + if (!tuple_) { + size_t h = 0; + for (const auto& element : tuple_elements_) { + h = tensorflow::Hash64Combine(h, element.Hash()); + } + return h; + } + if (replicated_) { + return 0; + } + size_t h = 0; + for (uint32 v : tile_assignment_) { + h = tensorflow::Hash64Combine(h, std::hash{}(v)); + } + for (uint32 v : tile_shape_.dimensions()) { + h = tensorflow::Hash64Combine(h, std::hash{}(v)); + } + return h; +} + std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { out << sharding.ToString(); return out; diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index f4a0fb626f2c3e417c020cbfa2f7168359a47788..34324d2058efe804cda486600dabd8a62cb84fda 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -19,7 +19,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ +#include #include +#include #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -70,26 +72,13 @@ class HloSharding { // Creates a new sharding for a tuple type. The given ShapeTree must have // elements for every leaf shape contained in the tuple. - static HloSharding Tuple(const ShapeTree& sub_shardings) { - std::vector flattened_list; - flattened_list.reserve( - std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); - for (const auto& index_to_sharding : sub_shardings.leaves()) { - flattened_list.push_back(index_to_sharding.second); - } - return HloSharding(flattened_list); - } + static HloSharding Tuple(const ShapeTree& sub_shardings); - // Creates a new sharding for a tuple type. The requested tuple shape must not - // be nested. For nested tuples, use the ShapeTree overload. + // Creates a new sharding for a tuple type. The number of elements in + // shardings must match the number of leaf nodes in tuple_shape. For + // empty tuples, the shardings array must have one element. static HloSharding Tuple(const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings) { - CHECK(ShapeUtil::IsTuple(tuple_shape)); - CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); - std::vector flattened_list(shardings.begin(), shardings.end()); - CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); - return HloSharding(flattened_list); - } + tensorflow::gtl::ArraySlice shardings); // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); @@ -131,6 +120,14 @@ class HloSharding { // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; + // Retrieves an histogram of the devices used by the sharding. The returned + // map has the device number as key, and the occurrence count as value. + // If a sharding does not have a device, it will not be incuded in the + // histogram. The count argument, if not nullptr, will receive the total + // number of elements this sharding is made of (one for array, N leaves for + // tuples). + std::map UsedDevices(int64* count) const; + // Returns the tile that should be executed on the given device. // REQUIRES: !IsTuple() std::vector TileIndexForDevice(int64 device) const; @@ -172,6 +169,18 @@ class HloSharding { // REQUIRES: IsTuple() HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const; + // If the current sharding is a tuple sharding, return itself as result. + // Otherwise returns a tuple sharding for the input shape, with all the leaves + // having this object sharding. + StatusOr GetTupleSharding(const Shape& shape) const; + + // Extracts the sharding that is common within the current sharding. + // If the current sharding is not a tuple sharding, the current sharding will + // be returned. If it is a tuple, and all the tuple elements are common, the + // common element will be returned. Otherwise the optional will contain no + // value. + tensorflow::gtl::optional ExtractSingleSharding() const; + bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && @@ -180,26 +189,7 @@ class HloSharding { } bool operator!=(const HloSharding& other) const { return !(*this == other); } - size_t Hash() const { - if (!tuple_) { - size_t h = 0; - for (const auto& element : tuple_elements_) { - h = tensorflow::Hash64Combine(h, element.Hash()); - } - return h; - } - if (replicated_) { - return 0; - } - size_t h = 0; - for (uint32 v : tile_assignment_) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } - for (uint32 v : tile_shape_.dimensions()) { - h = tensorflow::Hash64Combine(h, std::hash{}(v)); - } - return h; - } + size_t Hash() const; struct Hasher { size_t operator()(const HloSharding& sharding) const { @@ -241,6 +231,12 @@ class HloSharding { tuple_(false), tile_shape_(), tile_assignment_({0}) {} + // device_id values: + // -2: magic number to mean unassigned device, used by spatial partitioning + // -1: the id of the host + // 0 or positive: the id of a device + // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once + // we have fully switched to the side-effect tokens. explicit HloSharding(int64 device_id) : replicated_(false), maximal_(true), @@ -260,11 +256,19 @@ class HloSharding { tile_assignment_({0}), tuple_elements_(tuple_shardings) {} + // Checks that the number of elements in tuple_elements_ is consistent with + // the tuple shape passes as argument. + Status CheckLeafCount(const Shape& shape) const; + // Internal helper to validate a tuple sharding. Status ValidateTuple(const Shape& shape, int64 num_devices) const; + // Internal helper to validate a non-tuple (leaf) sharding. Status ValidateNonTuple(const Shape& shape, int64 num_devices) const; + // Returns the number of tuple_elements_ entries to fit the shape. + static int64 RequiredLeaves(const Shape& shape); + bool replicated_; bool maximal_; bool tuple_; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 82cff2a4b7146c2d454feb2d90673d419ca1a54d..748273a43cecca7a9c7392bb84f0e4c7133cfb14 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -31,32 +31,22 @@ struct PassThrough { HloInstruction* operand = nullptr; }; -void SetDeviceSharding(HloInstruction* instruction, int64 device) { - VLOG(4) << " " << instruction->name() << " to device " << device; - instruction->set_device_sharding(device); -} - -tensorflow::gtl::optional ShardingUniqueDevice( - const HloSharding& sharding) { - if (sharding.IsTileMaximal()) { - auto device = sharding.UniqueDevice(); - if (device.ok()) { - return device.ValueOrDie(); - } - } - return tensorflow::gtl::optional(); +void SetSingleSharding(HloInstruction* instruction, + const HloSharding& sharding) { + VLOG(4) << " " << instruction->name() << " to " << sharding; + instruction->set_single_sharding(sharding); } bool ShardingMatches(const HloSharding& sharding1, const HloSharding& sharding2) { - auto device1 = ShardingUniqueDevice(sharding1); - if (device1) { - auto device2 = ShardingUniqueDevice(sharding2); - if (device2) { - return *device1 == *device2; + auto single_sharding1 = sharding1.ExtractSingleSharding(); + if (single_sharding1) { + auto single_sharding2 = sharding2.ExtractSingleSharding(); + if (single_sharding2) { + return *single_sharding1 == single_sharding2; } } - // Anything which is not tile maximal with unique device, gets a full sharding + // Anything which is not unique across all elements, gets a full sharding // compare. return sharding1 == sharding2; } @@ -119,21 +109,21 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, std::unique_ptr CloneShardingForDomain( const HloSharding& sharding) { - auto device = ShardingUniqueDevice(sharding); - if (!device) { + auto single_sharding = sharding.ExtractSingleSharding(); + if (!single_sharding) { return MakeUnique(sharding); } - return MakeUnique(HloSharding::AssignDevice(*device)); + return MakeUnique(*single_sharding); } -Status ApplyDomainDeviceSharding(const DomainMetadata::Domain& domain, - int64 device) { - VLOG(4) << "Applying device " << device << " sharding"; +Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + VLOG(4) << "Applying " << sharding << " sharding"; for (HloInstruction* instruction : domain.instructions) { // We only change instructions without sharding, since otherwise we might // mess up with eventual HLO passes which has knowledge of it. if (!instruction->has_sharding()) { - SetDeviceSharding(instruction, device); + SetSingleSharding(instruction, sharding); } else { VLOG(4) << " " << instruction->name() << " already has sharding " << instruction->sharding(); @@ -186,12 +176,15 @@ StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, const HloSharding* tuple_sharding = GetOperandSharding(tuple, domain, sharding); if (tuple_sharding != nullptr) { - TF_RET_CHECK(tuple_sharding->IsTuple()) << tuple->ToString(); - HloSharding sub_sharding = tuple_sharding->GetSubSharding( - tuple->shape(), {instruction->tuple_index()}); - VLOG(4) << " " << instruction->name() << " to sharding " - << sub_sharding; - instruction->set_sharding(sub_sharding); + if (tuple_sharding->IsTuple()) { + HloSharding sub_sharding = tuple_sharding->GetSubSharding( + tuple->shape(), {instruction->tuple_index()}); + VLOG(4) << " " << instruction->name() << " to sharding " + << sub_sharding; + instruction->set_sharding(sub_sharding); + } else { + SetSingleSharding(instruction, *tuple_sharding); + } ++assigned; } } else if (instruction->opcode() == HloOpcode::kTuple) { @@ -242,12 +235,29 @@ StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, Status ApplyDomainSharding(const DomainMetadata::Domain& domain, const HloSharding& sharding) { - auto device = ShardingUniqueDevice(sharding); - if (device) { - // Shortcut the simple case. We have a unique device sharding, so we call - // the ApplyDomainDeviceSharding() API which will apply array or tuple - // shaped device sharding to the domain instructions. - return ApplyDomainDeviceSharding(domain, *device); + // Here is the place to call external sharding normalizers, which are + // implemented in other modules (ie, spatial partitioning). + // The signature of the external normalizer function should be something + // like: + // + // StatusOr Normalizer(const DomainMetadata::Domain&, + // const HloSharding& sharding); + // + // The function should return true if it has processed the domain + // normalization, false if domain was not one recognized by it, or an error. + // We will call the functions in order below, and fall back to local code if + // none of the external normalizers acted on the domain. + // External normalizers should not handle the cases that are already handled + // locally. + + // None of the external normalizers handled the domain sharding, try to see + // whether this is a single sharding first. + auto single_sharding = sharding.ExtractSingleSharding(); + if (single_sharding) { + // Shortcut the simple case. We have a unique sharding, so we call + // the ApplyDomainSingleSharding() API which will apply array or tuple + // shaped sharding to the domain instructions. + return ApplyDomainSingleSharding(domain, *single_sharding); } VLOG(1) << "Assigning non-trivial sharding " << sharding; for (;;) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index ee7133689b15348a18e6db9181199d5b25bf8143..54b7402b866361748d9eb35182b0bf486c4c9bdc 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -321,8 +321,10 @@ TEST_F(HloShardingTest, ParseHloString) { check(HloSharding::AssignDevice(2)); check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), Array4D({{{{0}, {1}}}}))); - // Empty tuple. - check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), {})); + // Empty tuple. One sharding is required for empty tuples, as we need to be + // able to assign sharding to them, even though they have no leaves. + check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), + {HloSharding::Replicate()})); { // Non-nested tuple. auto tuple_shape = diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 9cfd8a9bf74bc69ac40b1e0974d9e084d31071c9..1d6cd4cb2308fd09c7511e390a146a5224f253a3 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -426,6 +426,15 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } +Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) { + std::vector operand_shapes; + for (const HloInstruction* operand : token->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(token, + ShapeInference::InferGenerateTokenShape(operand_shapes)); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -791,6 +800,46 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { return Status::OK(); } +namespace { + +// Returns true if the given Shape has a TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. For example, TOKEN types have no Literal representation and cannot be +// on the interface of the entry computation (parameters and root instruction). +Status VerifyEntryAndExitShapes(const HloModule& module) { + for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape()).c_str()); + } + } + if (ShapeContainsToken( + module.entry_computation()->root_instruction()->shape())) { + return InternalError( + "Entry root is or contains a token shape: %s", + ShapeUtil::HumanString( + module.entry_computation()->root_instruction()->shape()) + .c_str()); + } + return Status::OK(); +} + +} // namespace + StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); @@ -851,6 +900,8 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); } + TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 1392a78097aa026b2f7cffa2b0135402d3ca7ae5..7283b3e7dcdbed5be18a1da1571287cf0c089288 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -81,6 +81,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleGenerateToken(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index dc3bfce0c495bc40a2df7b985cab67e02a3e15ce..d7458c338e9f1df9fac90270845aae0b8f779ee2 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -169,6 +169,23 @@ string HumanReadableProfileBuilder::ToString() const { StrAppend(&s, table.MakeReport(CyclesToMicroseconds(total_cycles_))); } } + + if (total_bytes > 0) { + MetricTableReport table; + table.SetMetricName("MiB read+written"); + table.SetEntryName("ops"); + table.SetShowCategoryTable(); + for (const auto& op : op_infos_) { + MetricTableReport::Entry entry; + entry.text = op.name; + entry.short_text = op.short_name; + entry.category_text = op.category; + entry.metric = static_cast(op.bytes_accessed) / (1 << 20); + table.AddEntry(std::move(entry)); + } + StrAppend(&s, + table.MakeReport(static_cast(total_bytes) / (1 << 20))); + } return s; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 8b3fa6c1572cf0ed91fc427722edcb23d8b8529d..1985d20578677ae68b244023c4640454b004bf49 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -28,6 +28,7 @@ namespace { using Analysis = IndexedArrayAnalysis; using UnknownArray = Analysis::UnknownArray; using ConstantArray = Analysis::ConstantArray; +using ReshapedArray = Analysis::ReshapedArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; using tensorflow::gtl::ArraySlice; using tensorflow::str_util::Join; @@ -52,6 +53,13 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { "(constant ", ShapeUtil::HumanString(root->shape()), ")"); } + case Array::kReshaped: { + ReshapedArray* reshaped_array = root->as(); + return tensorflow::strings::StrCat( + "(reshape ", ToString(reshaped_array->operand(), print_constants), + " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")"); + } + case Array::kScalarIndexedConstant: case Array::kScalarIndexed: { auto* indexed_array = root->as(); @@ -239,15 +247,40 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForGather( tensorflow::gtl::ArraySlice window_bounds, Array* source, Array* indices) { if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { + VLOG(3) << "ComputeArrayForGather: indices are not scalar"; return nullptr; } CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); - if (!c_binary_search(dim_numbers.elided_window_dims(), - dim_numbers.gather_dims_to_operand_dims(0))) { + + // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should + // it become relevant. + + if (dim_numbers.elided_window_dims_size() != 1 || + dim_numbers.elided_window_dims(0) != + dim_numbers.gather_dims_to_operand_dims(0)) { + VLOG(3) << "ComputeArrayForGather: gather operations must elide " + "gather_dims_to_operand_dims[0] and " + "gather_dims_to_operand_dims[0] only"; return nullptr; } + // ScalarIndexedArray cannot represent gathers that "slice" along some + // dimensions -- for instance it cannot represent a gather that picks 5 [2,3] + // arrays from an array of size [7,4,6]. We check that condition down below: + + for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) { + if (i != dim_numbers.elided_window_dims(0) && + source->shape().dimensions(i) != window_bounds[i]) { + VLOG(3) << "ComputeArrayForGather: window_bounds[" << i + << "] != source->shape().dimensions(" << i << ") -- " + << source->shape().dimensions(i) << " vs. " << window_bounds[i] + << " with dim_numbers.elided_window_dims(0) = " + << dim_numbers.elided_window_dims(0); + return nullptr; + } + } + int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); std::vector output_dims; for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { @@ -336,7 +369,11 @@ std::vector ComputeReshapePassthroughDimPairs( // result_subarray_size does not include the elements in the current // `result_dim` dimension (we multiply in result_shape[result_dim] at the // end of loop body) so candidate_operand_dim can never be zero. - CHECK_NE(candidate_operand_dim, 0); + CHECK_NE(candidate_operand_dim, 0) + << "result_dim = " << result_dim + << ", result_subarray_size = " << result_subarray_size + << ", result_shape = [" << Join(result_shape, ",") << "]" + << ", operand_shape = [" << Join(operand_shape, ",") << "]"; if (candidate_operand_dim != -1 && result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { @@ -357,7 +394,7 @@ std::vector ComputeReshapePassthroughDimPairs( }); VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" << Join(result_shape, ",") << "] passthrough indices are [" - << Join(result_strings, ",") << "]"; + << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; } DCHECK(c_is_sorted( @@ -398,6 +435,10 @@ int64 MapPassthroughOperandDimToResultDim( int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, ArraySlice result_shape, int64 source_passthrough_dim) { + VLOG(3) << "FindSourcePositionForPassthroughResultDim([" + << Join(operand_shape, ",") << "], [" << Join(result_shape, ",") + << "], " << source_passthrough_dim << ")"; + int64 indexed_source_subarray_size = std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, operand_shape.end(), 1, std::multiplies()); @@ -405,15 +446,191 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); } +Shape StripDegenerateDimensions(const Shape& shape) { + DimensionVector new_dims; + c_copy_if(shape.dimensions(), std::back_inserter(new_dims), + [](int64 dim) { return dim != 1; }); + return ShapeUtil::MakeShape(shape.element_type(), new_dims); +} }; // namespace -StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( - const Shape& shape, Array* operand) { - auto* scalar_indexed = dynamic_cast(operand); - if (!scalar_indexed) { +StatusOr +IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims( + ScalarIndexedArray* operand) { + const Shape& shape = operand->shape(); + if (!ShapeUtil::HasDegenerateDimensions(shape)) { + return operand; + } + + // We only need to reshape out the degenerate dims from the indices and the + // source (except the source dim). + + const Shape& source_shape = operand->source()->shape(); + DimensionVector new_source_shape_dims; + for (int64 i = 0, e = source_shape.dimensions_size(); i < e; i++) { + if (i == operand->source_dim() || source_shape.dimensions(i) != 1) { + new_source_shape_dims.push_back(source_shape.dimensions(i)); + } + } + + Shape new_source_shape = + ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims); + Shape new_indices_shape = + StripDegenerateDimensions(operand->indices()->shape()); + + TF_ASSIGN_OR_RETURN( + Array* const new_source, + ComputeArrayForReshape(new_source_shape, operand->source())); + TF_ASSIGN_OR_RETURN( + Array* const new_indices, + ComputeArrayForReshape(new_indices_shape, operand->indices())); + + // Build the new output dims while keeping track of the degenerate dims that + // will no longer be present. + DimensionVector new_output_dims; + int64 degenerate_dims_seen = 0; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (shape.dimensions(i) == 1) { + degenerate_dims_seen++; + } else if (ArrayContains(operand->output_dims(), i)) { + new_output_dims.push_back(i - degenerate_dims_seen); + } + } + + // Similarly, build the new source dim while keeping track of the degenerate + // dims that will no longer be present. + int64 degenerate_dims_before_source_dim = + std::count(source_shape.dimensions().begin(), + source_shape.dimensions().begin() + operand->source_dim(), 1); + int64 new_source_dim = + operand->source_dim() - degenerate_dims_before_source_dim; + + return ConstructScalarIndexedArray( + new_source, new_indices, new_source_dim, + InlinedVectorToVector(new_output_dims), + StripDegenerateDimensions(operand->shape())); +} + +StatusOr IndexedArrayAnalysis::ReshapeToAddDegenerateDims( + ScalarIndexedArray* operand, + tensorflow::gtl::ArraySlice degenerate_dims) { + if (degenerate_dims.empty()) { + return operand; + } + + CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape())); + + DimensionVector new_output_dims = [&]() { + // To make things easy we use a "scratch" buffer of bools where the i'th + // element is true iff the i'th component of the result index is an output + // index. + + gtl::InlinedVector output_dims_bitvector( + operand->shape().dimensions_size()); + for (int64 output_dim : operand->output_dims()) { + output_dims_bitvector[output_dim] = true; + } + + for (int64 degenerate_dim : degenerate_dims) { + InsertAt(&output_dims_bitvector, degenerate_dim, false); + } + + DimensionVector result; + result.reserve(operand->output_dims().size()); + for (int64 i = 0, e = output_dims_bitvector.size(); i < e; i++) { + if (output_dims_bitvector[i]) { + result.push_back(i); + } + } + + return result; + }(); + + DimensionVector new_result_shape_dims; + c_copy(operand->shape().dimensions(), + std::back_inserter(new_result_shape_dims)); + for (int64 degenerate_dim : degenerate_dims) { + InsertAt(&new_result_shape_dims, degenerate_dim, 1); + } + + DimensionVector new_source_shape_dims = new_result_shape_dims; + for (int64 output_dim : new_output_dims) { + EraseAt(&new_source_shape_dims, output_dim); + } + + int64 new_source_dim = [&]() { + for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) { + int64 non_degenerate_dims_seen = 0; + if (non_degenerate_dims_seen == operand->source_dim()) { + return i; + } + if (new_source_shape_dims[new_source_dim] != 1) { + non_degenerate_dims_seen++; + } + } + LOG(FATAL) << "Did not find source dim in " << ToString(operand); + }(); + + int64 source_dim_size = + operand->source()->shape().dimensions(operand->source_dim()); + InsertAt(&new_source_shape_dims, /*index=*/new_source_dim, + /*value=*/source_dim_size); + + Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + new_source_shape_dims); + Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(), + new_result_shape_dims); + + TF_ASSIGN_OR_RETURN( + Array* const new_source, + ComputeArrayForReshape(new_source_shape, operand->source())); + return ConstructScalarIndexedArray( + new_source, operand->indices(), new_source_dim, + InlinedVectorToVector(new_output_dims), new_result_shape); +} + +StatusOr IndexedArrayAnalysis::FoldReshapeOfGather( + const Shape& shape, ScalarIndexedConstantArray* operand) { + VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")"; + + // To make things easier on ourselves, instead of directly trying to fold the + // reshape of `operand` to `shape`, we call + // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims and + // handle the degenerate dimensions here by inserting reshapes. + + TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims, + ReshapeToRemoveDegenerateDims(operand)); + + Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape); + TF_ASSIGN_OR_RETURN( + ScalarIndexedArray* const folded_reshape_without_degenerate_dims, + FoldReshapeOfGatherNoDegenerateDims( + output_shape_without_degenerate_dims, + operand_without_degenerate_dims->as())); + + if (folded_reshape_without_degenerate_dims == nullptr) { return nullptr; } + DimensionVector degenerate_result_dims; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (shape.dimensions(i) == 1) { + degenerate_result_dims.push_back(i); + } + } + + return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims, + degenerate_result_dims); +} + +StatusOr +IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( + const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) { + VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed) + << ")"; + CHECK(!ShapeUtil::HasDegenerateDimensions(shape)); + CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape())); + // Try to fold Reshape(ScalarIndexed(Const, Indices)) // => ScalarIndexed(Const', Indices) // @@ -464,7 +681,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( std::vector reshape_passthrough_dims = ComputeReshapePassthroughDimPairs( - /*operand_shape=*/AsInt64Slice(operand->shape().dimensions()), + /*operand_shape=*/AsInt64Slice(scalar_indexed->shape().dimensions()), /*result_shape=*/AsInt64Slice(shape.dimensions())); auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) { @@ -474,6 +691,8 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( if (!c_all_of(scalar_indexed->output_dims(), is_reshape_passthrough_operand_dim)) { + VLOG(3) << "Not all output dims are passthrough dims " + << ToString(scalar_indexed); return nullptr; } @@ -527,6 +746,11 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( // (a.k.a. isn't pass-through) than the [3,5,2] array. if (source_dim_for_new_scalar_indexed_node == -1) { + VLOG(3) << "Could not compute the source dim for the new scalar indexed " + "node: scalar_indexed_source_shape = [" + << Join(scalar_indexed_source_shape.dimensions(), ",") + << "] and new_scalar_indexed_source_shape = [" + << Join(new_scalar_indexed_source_shape, ",") << "]"; return nullptr; } @@ -534,6 +758,10 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); + CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1l, + std::multiplies()), + ShapeUtil::ElementsIn(scalar_indexed_source_shape)); + CHECK(IsReshapePassthroughOperandDim( ComputeReshapePassthroughDimPairs( /*operand_shape=*/AsInt64Slice( @@ -564,6 +792,31 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( output_dims_for_new_scalar_indexed_node, shape); } +StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( + const Shape& shape, Array* operand) { + if (ShapeUtil::Compatible(operand->shape(), shape)) { + return operand; + } + + if (auto* scalar_indexed = + dynamic_cast(operand)) { + TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather, + FoldReshapeOfGather(shape, scalar_indexed)); + if (reshape_folded_into_gather) { + return reshape_folded_into_gather; + } + } + + if (auto* constant_array = dynamic_cast(operand)) { + TF_ASSIGN_OR_RETURN(Literal* const new_literal, + TakeOwnership(constant_array->literal()->Reshape( + AsInt64Slice(shape.dimensions())))); + return Construct(new_literal); + } + + return Construct(operand, shape); +} + StatusOr IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, Array* lhs, diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index ce92fd2919c90fa8a2fb7b796ed6f0fdaf48fe62..8684430231c1929f82508e3675f1c275c42b6149 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -39,7 +39,13 @@ class IndexedArrayAnalysis { // Array instances are immutable once created. class Array { public: - enum Kind { kUnknown, kConstant, kScalarIndexedConstant, kScalarIndexed }; + enum Kind { + kUnknown, + kConstant, + kReshaped, + kScalarIndexedConstant, + kScalarIndexed + }; virtual Kind kind() const = 0; virtual const Shape& shape() const = 0; @@ -96,6 +102,27 @@ class IndexedArrayAnalysis { friend class IndexedArrayAnalysis; }; + // Represents an Array that is a reshape of another Array. + class ReshapedArray : public Array { + public: + Kind kind() const override { return kReshaped; } + + // The array to reshape. + Array* operand() const { return operand_; } + + // The output shape. + const Shape& shape() const override { return shape_; } + + private: + explicit ReshapedArray(Array* operand, Shape shape) + : operand_(operand), shape_(shape) {} + + Array* operand_; + const Shape shape_; + + friend class IndexedArrayAnalysis; + }; + // --------------------------------------------------------------------------- // Indexed Array Overview // --------------------------------------------------------------------------- @@ -266,6 +293,21 @@ class IndexedArrayAnalysis { ScalarIndexedArray* source, Array* indices, int64 source_dim, tensorflow::gtl::ArraySlice output_dims, Shape shape); + // Reshapes a scalar-indexed node to remove the degenerate dimensions in its + // output. The result is always a scalar-indexed node. + StatusOr ReshapeToRemoveDegenerateDims( + ScalarIndexedArray* operand); + + // Reshapes a scalar-indexed node such that the result has the degenerate + // dimensions `degenerate_dims`. The result is always a scalar-indexed node. + StatusOr ReshapeToAddDegenerateDims( + ScalarIndexedArray* operand, + tensorflow::gtl::ArraySlice degenerate_dims); + + StatusOr FoldReshapeOfGather( + const Shape& shape, ScalarIndexedConstantArray* operand); + StatusOr FoldReshapeOfGatherNoDegenerateDims( + const Shape& shape, ScalarIndexedConstantArray* scalar_indexed); StatusOr ComputeArrayForReshape(const Shape& shape, Array* operand); StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 373556ebeba883f7dc2116bdf0ffc3274182f775..fc2befe05b18651502c42b9892e766145d85f2e8 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -34,6 +36,27 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { } private: + // Replaces seqences of whitespace with a single space. This makes the + // strings being matched against "whitespace insensitive" which lets us indent + // them for readability. + string CanonicalizeWhitespace(const string& text) { + string result; + + for (char c : text) { + if (!isspace(c)) { + result.push_back(c); + } else if (!result.empty() && result.back() != ' ') { + result.push_back(' '); + } + } + + while (!result.empty() && result.back() == ' ') { + result.pop_back(); + } + + return result; + } + void AssertArrayForRootExpressionIsImpl(const string& hlo_text, const string& root_expression, bool print_constants) { @@ -44,10 +67,10 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { IndexedArrayAnalysis::Array* const array_result, indexed_tensor_analysis.GetArrayFor( module().entry_computation()->root_instruction())); - string string_result = - indexed_tensor_analysis.ToString(array_result, print_constants); + string string_result = CanonicalizeWhitespace( + indexed_tensor_analysis.ToString(array_result, print_constants)); LOG(INFO) << string_result; - ASSERT_EQ(string_result, root_expression); + ASSERT_EQ(string_result, CanonicalizeWhitespace(root_expression)); } }; @@ -91,6 +114,82 @@ ENTRY main { hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])"); } +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices = s32[5,2] parameter(0) + ROOT gather = s32[5] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed1) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3,1] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,2}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed2) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3,1] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,2,3] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={2}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={2,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed3) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,2} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%gather"); +} + TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { string hlo_text = R"( HloModule SimpleGather @@ -273,7 +372,157 @@ ENTRY main { "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])"); } -TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative0) { +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,6] constant(s32[2,6]{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}) + indices = s32[1] parameter(0) + gather = s32[1,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT reshape = s32[1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,6]) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } }) + + i.0 = s64[1,3]{1,0} parameter(0) + g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2}, + elided_window_dims={0}, gather_dims_to_operand_dims={0}, + index_vector_dim=2, window_bounds={1,3} + + i.1 = s64[1] parameter(1) + g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2}, + elided_window_dims={1}, gather_dims_to_operand_dims={1}, + index_vector_dim=1, window_bounds={1,1,3} + + ROOT reshape = s32[1,3]{1,0} reshape(g.1) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,3]) + (reshape + (scalar-indexed %i.0 %i.1 1->[1]) + to s64[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}}) + indices = s32[1] parameter(0) + gather = s32[1,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT reshape = s32[1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[1,1,1,6]) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[1,2,6] constant(s32[1,2,6]{{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}}) + indices = s32[1] parameter(0) + gather = s32[1,1,6] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={1,1,6} + ROOT reshape = s32[1,1,1,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,1,6] s32[2,1,1,1,6] { + { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } }, + { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } } }) + (reshape %indices to s32[]) + 0->[]) +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, + expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[2,6] constant(s32[2,6]{ + {1,2,3,4,5,6},{1,2,3,4,5,6}}) + indices = s32[1,5] parameter(0) + gather = s32[1,5,6] gather(operand, indices), + output_window_dims={2}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,6} + ROOT reshape = s32[1,1,5,6] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(scalar-indexed-const + (constant s32[2,1,1,6] s32[2,1,1,6] { + { /*i0=0*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } }, + { /*i0=1*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } } }) + (reshape %indices to s32[5]) + 0->[2]) +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, + expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) { string hlo_text = R"( HloModule ReshapeOfGather @@ -290,10 +539,19 @@ ENTRY main { } )"; - AssertArrayForRootExpressionIs(hlo_text, "%reshape"); + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,4]) + %indices + 0->[0,2]) + to s32[5,2,2,2,3]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); } -TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative1) { +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) { string hlo_text = R"( HloModule ReshapeOfGather @@ -313,7 +571,48 @@ ENTRY main { } )"; - AssertArrayForRootExpressionIs(hlo_text, "%reshape"); + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,5,2]) + %indices + 1->[2]) + to s32[6,7]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4,1] constant(s32[3,4,1]{ + {{1},{2},{3},{4}}, + {{1},{2},{3},{4}}, + {{1},{2},{3},{4}}}) + indices = s32[5,6] parameter(0) + gather = s32[5,4,6,1] gather(operand, indices), + output_window_dims={1,3}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4,1} + ROOT reshape = s32[5,2,2,2,3,1] reshape(gather) +} +)"; + + const char* expected_root_expression = R"( +(reshape + (scalar-indexed-const + (constant s32[3,4,1]) + %indices + 0->[0,2]) + to s32[5,2,2,2,3,1]) +)"; + + AssertArrayForRootExpressionIs(hlo_text, expected_root_expression); } TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 429c8503432b79f46aa0e5b1970bb565093128dd..d1c4c91b34a71cddd90022b91d6e105aa932f402 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -96,6 +96,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: + case HloOpcode::kGenerateToken: case HloOpcode::kTranspose: case HloOpcode::kTuple: return false; @@ -280,10 +281,8 @@ StatusOr InstructionFusion::Run(HloModule* module) { // map from HloInstruction* to the instruction's index in the vector. An // instruction is "removed" from the vector by setting it's element to // nullptr. - std::list post_order_list = + std::vector post_order = computation_->MakeInstructionPostOrder(); - std::vector post_order(post_order_list.begin(), - post_order_list.end()); tensorflow::gtl::FlatMap post_order_index; for (size_t i = 0; i < post_order.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index c1666530687f2f8407a9dcb4e271c9d95552a689..9f8f4bda875cdff5e20fa8ca8eeecaa1140e2b9c 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -44,7 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->mutable_device_entry_computation_layout()); + hlo_module->mutable_entry_computation_layout()); return pipeline.Run(hlo_module).status(); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 029e71058a7373b9310c6d9ffdb65f72ca28e5af..9816acf6507a0ed5391cf4f1c94ccd0f27f5227a 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -75,9 +75,9 @@ StatusOr InterpreterExecutable::ExecuteOnStream( // consumes. std::vector> arg_literals; for (int64 p = 0; p < computation->num_parameters(); ++p) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr arg_literal, - transfer_manager->TransferLiteralFromDevice(executor, *arguments[p])); + TF_ASSIGN_OR_RETURN(std::unique_ptr arg_literal, + transfer_manager->TransferLiteralFromDevice( + run_options->stream(), *arguments[p])); arg_literals.push_back(std::move(arg_literal)); } @@ -96,7 +96,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( result_literal->shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - executor, *result_literal, result)); + run_options->stream(), *result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 97e9fa2c8e8ecd918ffe3df2fd4e731f3b91e6db..4fb67bd0b72fc591c1ffa76ebb0513bf14ed3737 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -53,6 +53,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst, AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() { port::Status ok = SynchronousMemcpy(host_dst, dev_src, size); }); + AsExecutorStream(stream)->BlockUntilDone(); return true; } @@ -61,6 +62,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst, AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() { port::Status ok = SynchronousMemcpy(dev_dst, host_src, size); }); + AsExecutorStream(stream)->BlockUntilDone(); return true; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 7067b6f86a0fb24fb946ad236bca9bbd48d53722..b319518421339e0b4ee284f81aa30304d7689298 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -175,41 +175,32 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); - const BufferLayoutConstraint* curr_constraint = - GetBufferLayoutConstraint(buffer); - if (curr_constraint != nullptr) { - if (LayoutUtil::Equal(curr_constraint->layout(), layout)) { + auto iter = buffer_constraints_.find(&buffer); + if (iter != buffer_constraints_.end()) { + const BufferLayoutConstraint& curr_constraint = iter->second; + if (LayoutUtil::Equal(curr_constraint.layout(), layout)) { // New constraint matches existing constraint. Nothing to do. return Status::OK(); } - if (curr_constraint->mandatory()) { + if (curr_constraint.mandatory()) { return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", buffer.ToString().c_str(), - LayoutUtil::HumanString(curr_constraint->layout()).c_str(), + LayoutUtil::HumanString(curr_constraint.layout()).c_str(), LayoutUtil::HumanString(layout).c_str()); } - } - - auto iter = buffer_constraints_.find(&buffer); - bool overwrite = iter != buffer_constraints_.end(); - if (!overwrite) { + iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); + } else { + TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1) + << buffer.ToString(); iter = buffer_constraints_ .insert(std::make_pair( &buffer, BufferLayoutConstraint(layout, buffer, mandatory, dfs))) .first; - } else { - iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } added_constraints_.push_back(&iter->second); - - // Remove buffer from the set of unconstrained buffers. - TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == - static_cast(!overwrite)); - unconstrained_buffer_ids_.erase(buffer.id()); - return Status::OK(); } @@ -716,7 +707,8 @@ Status CheckParameterLayout(HloInstruction* parameter, const ComputationLayout& computation_layout) { const ShapeLayout& parameter_layout = computation_layout.parameter_layout(parameter->parameter_number()); - if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) { + if (parameter_layout.LayoutIsSet() && + !parameter_layout.MatchesLayoutInShape(parameter->shape())) { return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", @@ -936,14 +928,15 @@ LayoutAssignment::LayoutAssignment( ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), + saved_entry_computation_layout_(*entry_computation_layout), channel_layout_constraints_(channel_constraints) { + if (channel_layout_constraints_ != nullptr) { + // Save a copy of the input ChannelLayoutConstraints so that we can reset it + // if we have to undo previous operations (ClearPreviousPassSideEffects()). + channel_constraints_ = *channel_layout_constraints_; + } VLOG(1) << "Entry computation layout given to layout assignment: " << entry_computation_layout_->ToString(); - // Layouts of all parameter instructions must be set. - for (const ShapeLayout& parameter_layout : - entry_computation_layout_->parameter_layouts()) { - CHECK(parameter_layout.LayoutIsSet()); - } } std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( @@ -1614,13 +1607,57 @@ Status LayoutAssignment::RunOnComputation( // Record the layouts assigned for any communication ops in // channel_constraints so that they are constrained for future modules. + if (channel_constraints != nullptr) { + TF_RETURN_IF_ERROR( + ConstrainChannelLayouts(computation, channel_constraints)); + } + return Status::OK(); +} + +Status LayoutAssignment::ConstrainChannelLayouts( + HloComputation* computation, + ChannelLayoutConstraints* channel_constraints) { + // We go through the kRecvDone before. These must either impose their layout, + // of find a matching one already existing (ConstrainChannel() returns + // nullptr). for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kRecvDone) { + const Layout* layout = channel_constraints->ConstrainChannel( + instruction->channel_id(), instruction->shape().layout()); + TF_RET_CHECK(layout == nullptr) + << instruction->ToString() + << " cannot constrain layout as it was set to " + << LayoutUtil::HumanString(*layout); + } + } + // After that we go through the kSend. These are likely going to have a kCopy + // as operand (otherwise we add it), so in case the constrained layout does + // not match, we can change the kCopy layout (and the kSend one as well). + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kSend) { - channel_constraints->ConstrainChannel( - instruction->channel_id(), instruction->operand(0)->shape().layout()); - } else if (instruction->opcode() == HloOpcode::kRecvDone) { - channel_constraints->ConstrainChannel(instruction->channel_id(), - instruction->shape().layout()); + HloInstruction* operand = instruction->mutable_operand(0); + const Layout* layout = channel_constraints->ConstrainChannel( + instruction->channel_id(), operand->shape().layout()); + if (layout != nullptr) { + // We found an already constrained layout which does not match the one + // the kSend wants to impose. Eitehr add a new kCopy, or use the + // existing one to marshal the correct shape. + Shape shape = operand->shape(); + *shape.mutable_layout() = *layout; + if (operand->opcode() != HloOpcode::kCopy) { + HloInstruction* copy = operand->parent()->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand)); + RegisterAddedCopy(copy); + SetupCopiedInstruction(*operand, copy, {}); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy)); + operand = copy; + } else { + *operand->mutable_shape() = shape; + } + Shape* send_shape = + ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0}); + *send_shape = shape; + } } } return Status::OK(); @@ -1679,6 +1716,7 @@ StatusOr LayoutAssignment::Run(HloModule* module) { // root, we also fix up the eventually inconsistent ComputationLayout, which // will be then made mandatory by the second pass. for (int64 i = 0; i < 2; ++i) { + VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass"; TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module)); TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -1716,10 +1754,12 @@ StatusOr LayoutAssignment::Run(HloModule* module) { Status LayoutAssignment::Init() { computation_layouts_.clear(); + *entry_computation_layout_ = saved_entry_computation_layout_; return Status::OK(); } Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { + VLOG(5) << "Clearing previous side effects"; // Clear all the copies which have been added, and all the related // instructions (like GTE and tuples). int64 removed_copies = 0; @@ -1743,6 +1783,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); } + ResetChannelConstraints(); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index c287cca0c54ba1bb514bd8d243c137eca99b258f..0d7dde9c553333c6aad8e06b2712ba4c127f7762 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -249,25 +249,30 @@ class ChannelLayoutConstraints { // Given `shape`, apply the layout for `channel_id`. `channel_id` must already // be constrained. Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const { - CHECK(IsChannelConstrained(channel_id)); - *shape.mutable_layout() = constraints_.at(channel_id); + auto it = constraints_.find(channel_id); + CHECK(it != constraints_.end()) << "Channel " << channel_id; + *shape.mutable_layout() = it->second; return shape; } // Returns the layout constraint for `channel_id`, which must already be // constrained. - Layout LayoutForChannel(int64 channel_id) const { - CHECK(IsChannelConstrained(channel_id)); - return constraints_.at(channel_id); + const Layout& LayoutForChannel(int64 channel_id) const { + auto it = constraints_.find(channel_id); + CHECK(it != constraints_.end()) << "Channel " << channel_id; + return it->second; } // Adds a new layout constraint for `channel_id`. If a constraint for - // `channel_id` already exists, this operation requires that the new layout is - // the same as the previously constrained layout. - void ConstrainChannel(int64 channel_id, const Layout& layout) { - CHECK(!IsChannelConstrained(channel_id) || - LayoutUtil::Equal(layout, constraints_[channel_id])); - constraints_[channel_id] = layout; + // `channel_id` has been added, this API returns nullptr, otherwise returns + // the layout which has already been set for the channel. + const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) { + auto it = constraints_.emplace(std::make_pair(channel_id, layout)); + if (it.second) { + return nullptr; + } + return LayoutUtil::Equal(layout, it.first->second) ? nullptr + : &it.first->second; } private: @@ -427,8 +432,13 @@ class LayoutAssignment : public HloPassInterface { Status PropagateComputationLayouts(HloComputation* computation, ComputationLayout* computation_layout); + // The pointer to the ComputationLayout passed as constructor parameter. ComputationLayout* entry_computation_layout_; + // A copy of entry_computation_layout_ used to reset it to the initial values + // during the multiple passes done by the layout assignment operation. + ComputationLayout saved_entry_computation_layout_; + protected: // Sets up the copy instruction according to the characteristic (sharding, // metadata, ...) of the reference instruction. The index argument is used @@ -464,6 +474,20 @@ class LayoutAssignment : public HloPassInterface { // itself). Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number); + // Apply the channel layout constraints by populating the channel_constraints + // data structure passed in at constructor time. Eventually adds copies in + // case two ends of a channel ended up with a different leyout. + Status ConstrainChannelLayouts(HloComputation* computation, + ChannelLayoutConstraints* channel_constraints); + + // Resets the input ChannelLayoutConstraints to the original copy received + // from the constructor input. + void ResetChannelConstraints() { + if (channel_layout_constraints_ != nullptr) { + *channel_layout_constraints_ = channel_constraints_; + } + } + // Map containing the layouts of all computations assigned so // far. Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller @@ -474,7 +498,14 @@ class LayoutAssignment : public HloPassInterface { // here. tensorflow::gtl::FlatSet added_copies_; - ChannelLayoutConstraints* channel_layout_constraints_; + // The pointer to the channel layout constraints passed in with the + // constructor. If not nullptr, this is an input/output argument. + ChannelLayoutConstraints* channel_layout_constraints_ = nullptr; + + // A copy of the input layout constraints used to reset the above pointer in + // case we have to undo operations due to the multiple passes over the + // computations/instructions. + ChannelLayoutConstraints channel_constraints_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index bf0448a67674f24591d866b646b98aea09ebb12c..62599b376a12808232c703479a0ccfd7a59aa9ad 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -52,10 +52,18 @@ using ::testing::ElementsAre; class LayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, - ComputationLayout* entry_computation_layout) { - LayoutAssignment layout_assignment(entry_computation_layout); + ComputationLayout* entry_computation_layout, + ChannelLayoutConstraints* channel_constraints = nullptr) { + LayoutAssignment layout_assignment( + entry_computation_layout, /*channel_constraints=*/channel_constraints); EXPECT_IS_OK(layout_assignment.Run(module).status()); } + + std::vector LayoutOf(HloModule* module, tensorflow::StringPiece name) { + auto minor_to_major = + FindInstruction(module, name)->shape().layout().minor_to_major(); + return std::vector(minor_to_major.begin(), minor_to_major.end()); + } }; TEST_F(LayoutAssignmentTest, ComputationLayout) { @@ -707,17 +715,10 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { LayoutUtil::MakeLayout({2, 1, 0})); AssignLayouts(module.get(), &computation_layout); - auto layout_of = [&](tensorflow::StringPiece name) { - return FindInstruction(module.get(), name) - ->shape() - .layout() - .minor_to_major(); - }; - - EXPECT_THAT(layout_of("gte0"), ElementsAre(0, 1, 2)); - EXPECT_THAT(layout_of("gte1a"), ElementsAre(1, 2, 0)); - EXPECT_THAT(layout_of("gte1b"), ElementsAre(2, 0, 1)); - EXPECT_THAT(layout_of("fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0)); EXPECT_THAT(FindInstruction(module.get(), "gte1") ->shape() .tuple_shapes(0) @@ -816,5 +817,44 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { "Unexpected bitcast operation seen during layout assignment")); } +TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { + // Pin non matching layouts to parameter and root. + const char* module_str = R"( + HloModule test_module + + ENTRY entry_computation { + param = (f32[2,2]) parameter(0) + gte = f32[2,2] get-tuple-element(param), index=0 + recv = (f32[2,2], u32[]) recv(), channel_id=1, sharding={maximal device=1} + ROOT recv-done = f32[2,2] recv-done(recv), channel_id=1, + sharding={maximal device=1} + send = (f32[2,2], u32[]) send(gte), channel_id=1, + sharding={maximal device=0} + send-done = () send-done(send), channel_id=1, sharding={maximal device=0} + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + Shape param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); + TF_ASSERT_OK( + computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( + param_shape)); + computation_layout.mutable_result_layout()->ResetLayout( + LayoutUtil::MakeLayout({1, 0})); + + ChannelLayoutConstraints channel_constraints; + AssignLayouts(module.get(), &computation_layout, &channel_constraints); + + EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "recv-done"), ElementsAre(1, 0)); + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::GetSubshape( + FindInstruction(module.get(), "send")->shape(), {0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 21bca1d6beff5b2804531724b94b123d4523c173..f200a08a3cd7e33351ec4607d67d40e7ab28f3b9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -32,7 +32,8 @@ static const BufferAllocation* kParameterAllocation = new BufferAllocation( LogicalBuffer::Color(0)); void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, - llvm_ir::IrArray* array) { + llvm_ir::IrArray* array, + const ShapeIndex& index) { BufferAllocation::Slice buffer_slice; if (hlo.opcode() == HloOpcode::kParameter) { // Parameters may alias with each other but may not alias with our temporary @@ -40,7 +41,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0); } else { const std::set slices = - assignment_.GetAllSlices(&hlo, /*index=*/{}); + assignment_.GetAllSlices(&hlo, index); if (slices.empty() || slices.size() > 1) { // Skip HLOs which don't have a buffer assigned or for which the // buffer can't be determined statically. We cannot determine their @@ -137,16 +138,18 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( // 2. Operands of users of the given hlo. // 3. Operands of the given hlo. // - // This set can be increased as we need. For now only consider top-level - // buffers (index = {}) not buffers nested within the instruction's - // operands/output which are not typically touched. + // This set can be increased as we need. std::vector worklist; auto add_buffers_to_worklist = [&worklist, &assignment](const HloInstruction* instruction) { - for (const LogicalBuffer* buffer : - assignment.GetSourceBuffers(instruction, /*index=*/{})) { - worklist.push_back(buffer); - } + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& /*shape*/, const ShapeIndex& index) { + for (const LogicalBuffer* buffer : + assignment.GetSourceBuffers(instruction, index)) { + worklist.push_back(buffer); + } + }); }; for (HloInstruction* user : hlo.users()) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 5244ac61e56307857aca659854647bd6c3e991d7..fe9eab93aae95557e3ee27a64c09b78f37ac2348 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -38,7 +38,8 @@ class AliasAnalysis { // Augments IrArray with aliasing information. void AddAliasingInformationToIrArray(const HloInstruction& hlo, - llvm_ir::IrArray* array); + llvm_ir::IrArray* array, + const ShapeIndex& index = {}); private: // Returns a unique alias domain for this emitter. diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 7323abeb2077154f82828bcda3e90eb45a67138a..ea10cef49a4a9aa048b3e0ea443f052645c4912a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -29,9 +29,9 @@ limitations under the License. namespace xla { namespace llvm_ir { -static void Delinearize(std::vector* multidim, - llvm::Value* linear, const Shape& shape, - llvm::IRBuilder<>* ir_builder) { +void IrArray::Index::Delinearize(std::vector* multidim, + llvm::Value* linear, const Shape& shape, + llvm::IRBuilder<>* ir_builder) const { int64 divisor = 1; const Layout& layout = shape.layout(); for (int64 i = 0; i < layout.minor_to_major_size(); ++i) { @@ -48,10 +48,11 @@ static void Delinearize(std::vector* multidim, // useful because cuda-memcheck can't help us much in XLA: Most of our // memory lives in one big allocation, so cuda-memcheck can't detect // out-of-bounds accesses. - auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)); + auto* quot = + ir_builder->CreateUDiv(linear, GetConstantWithIndexType(divisor)); if (i < layout.minor_to_major_size() - 1) { (*multidim)[dimension] = ir_builder->CreateURem( - quot, ir_builder->getInt64(size_of_current_dimension)); + quot, GetConstantWithIndexType(size_of_current_dimension)); } else { (*multidim)[dimension] = quot; } @@ -65,6 +66,8 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK_NE(linear, nullptr); + index_type_ = linear->getType(); CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) << " should have a layout."; @@ -77,6 +80,13 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, linear_(linear), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + if (size()) { + index_type_ = multidim_[0]->getType(); + } else { + CHECK_NE(linear_, nullptr); + index_type_ = linear_->getType(); + } + CHECK_NE(index_type_, nullptr); CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)) << "Shape " << ShapeUtil::HumanStringWithLayout(shape) @@ -88,6 +98,9 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice multidim, : multidim_(multidim.begin(), multidim.end()), layout_(shape.layout()), dims_(shape.dimensions().begin(), shape.dimensions().end()) { + CHECK_GT(multidim_.size(), 0); + index_type_ = multidim[0]->getType(); + CHECK_NE(index_type_, nullptr); CHECK_EQ(shape.dimensions_size(), multidim.size()); CHECK(LayoutUtil::HasLayout(shape)); } @@ -130,15 +143,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); std::vector source_multidim_index( - ShapeUtil::Rank(input_shape), - llvm::UndefValue::get(builder->getInt64Ty())); + ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_)); // We compute the source indices in each common factor from only the target // indices in the same common factor. for (ssize_t k = common_factors.size() - 2; k >= 0; --k) { llvm::Value* logical_linear_index = Index(tensorflow::gtl::ArraySlice( multidim_, common_factors[k].second, - common_factors[k + 1].second - common_factors[k].second)) + common_factors[k + 1].second - common_factors[k].second), + index_type_) .Linearize( tensorflow::gtl::ArraySlice( AsInt64Slice(output_shape.dimensions()), @@ -150,9 +163,10 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( // linear index by each dimension size. for (int64 i = common_factors[k + 1].first - 1; i >= common_factors[k].first; --i) { - llvm::Value* divisor = builder->getInt64(input_shape.dimensions(i)); + llvm::Value* divisor = + GetConstantWithIndexType(input_shape.dimensions(i)); if (input_shape.dimensions(i) == 1) { - source_multidim_index[i] = builder->getInt64(0); + source_multidim_index[i] = GetConstantWithIndexType(0); } else if (i == common_factors[k].first) { source_multidim_index[i] = logical_linear_index; } else { @@ -168,14 +182,14 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { return Index(source_multidim_index, linear(), input_shape); } - return Index(source_multidim_index); + return Index(source_multidim_index, index_type_); } IrArray::Index IrArray::Index::SourceIndexOfSlice( const Shape& shape, tensorflow::gtl::ArraySlice starts, tensorflow::gtl::ArraySlice strides, llvm::IRBuilder<>* builder) const { - Index source_index(multidim_.size()); + Index source_index(index_type_, multidim_.size()); for (int i = 0; i < multidim_.size(); ++i) { int64 stride = strides[i]; auto type = multidim_[i]->getType(); @@ -224,11 +238,12 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast( // the physical index of the element in the buffer. This is like Linearize, // but takes the layout into account. int64 scale = 1; - llvm::Value* linear_index = builder->getInt64(0); + llvm::Value* linear_index = GetConstantWithIndexType(0); for (auto dimension : LayoutUtil::MinorToMajor(shape)) { linear_index = builder->CreateAdd( linear_index, - builder->CreateMul(multidim_[dimension], builder->getInt64(scale), "", + builder->CreateMul(multidim_[dimension], + GetConstantWithIndexType(scale), "", /*HasNUW=*/true, /*HasNSW=*/true), "", /*HasNUW=*/true, /*HasNSW=*/true); scale *= shape.dimensions(dimension); @@ -252,7 +267,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( } if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) || !LayoutUtil::HasLayout(shape)) { - return Index(source_index); + return Index(source_index, index_type_); } // High-level idea: we can reuse the linear index if the broadcasted // dimensions are contiguous, and this part of the operation is a bitcast. @@ -274,7 +289,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( bool contiguous_broadcast_dimensions = max_broadcasted_dimension - min_broadcasted_dimension == rank - 1; if (!contiguous_broadcast_dimensions) { - return Index(source_index); + return Index(source_index, index_type_); } // Check if the mapped dimensions are a bitcast. std::vector operand_logical_to_physical = @@ -282,7 +297,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( for (int64 i = 0; i < rank; ++i) { if (operand_logical_to_physical[i] != logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) { - return Index(source_index); + return Index(source_index, index_type_); } } llvm::Value* linear = linear_; @@ -291,7 +306,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } if (divisor > 1) { - linear = builder->CreateUDiv(linear, builder->getInt64(divisor)); + linear = builder->CreateUDiv( + linear, + IrArray::Index(linear->getType()).GetConstantWithIndexType(divisor)); } if (min_broadcasted_dimension > 0) { int64 mod = 1; @@ -299,7 +316,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast( ++i) { mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - linear = builder->CreateURem(linear, builder->getInt64(mod)); + linear = builder->CreateURem( + linear, + IrArray::Index(linear->getType()).GetConstantWithIndexType(mod)); } return Index(source_index, linear, operand_shape); } @@ -309,12 +328,13 @@ llvm::Value* IrArray::Index::Linearize( llvm::IRBuilder<>* builder) const { // Each dimension is multiplied by the product of the sizes of all // earlier dimensions and added to the accumulator logical_linear_index. - llvm::Value* logical_linear_index = builder->getInt64(0); + llvm::Value* logical_linear_index = GetConstantWithIndexType(0); int64 multiplier = 1; for (ssize_t i = size() - 1; i >= 0; --i) { llvm::Value* addend = - builder->CreateMul((*this)[i], builder->getInt64(multiplier), "", + builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "", /*HasNUW=*/true, /*HasNSW=*/true); + addend = builder->CreateZExtOrTrunc(addend, index_type_); logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "", /*HasNUW=*/true, /*HasNSW=*/true); multiplier *= dimensions[i]; @@ -349,7 +369,8 @@ llvm::Value* IrArray::EmitArrayElementAddress( // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to // produce better code in some cases. auto dim = shape_->dimensions(i); - actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); + actual_index.push_back( + dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]); } // "base_ptr_" has the type of "*" @@ -357,7 +378,9 @@ llvm::Value* IrArray::EmitArrayElementAddress( // should be computed by // // getelementptr base_ptr_, 0, most major index, ..., most minor index - std::vector gep_indices(1, ir_builder->getInt64(0)); + CHECK_GT(index.size(), 0); + std::vector gep_indices( + 1, llvm::ConstantInt::get(index[0]->getType(), 0)); for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_->layout(), i); gep_indices.push_back(actual_index[dimension]); @@ -410,7 +433,9 @@ IrArray IrArray::CastToShape(const Shape& new_shape, llvm::IRBuilder<>* ir_builder) { Index new_index = index; new_index[which_dimension] = ir_builder->CreateAdd( - index[which_dimension], ir_builder->getInt64(addend), "", /*HasNUW=*/true, + index[which_dimension], + llvm::ConstantInt::get(index[which_dimension]->getType(), addend), "", + /*HasNUW=*/true, /*HasNSW=*/true); return new_index; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 4c3195c29c859c9eef08e3f6531b059edbebfc47..4648c6d7ac089dbea7e660dd9889d557c8ad7318 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -53,18 +53,38 @@ class IrArray { // multidimensional index, which LLVM DCE can delete. class Index { public: - // Constructs an empty zero-dimensional index. - Index() {} - // Constructs an index of rank "size". Each dimension of the index is // initialized to "value". - explicit Index(size_t size, llvm::Value* value = nullptr) - : multidim_(size, value) {} + explicit Index(size_t size, llvm::Value* value) + : multidim_(size, value), index_type_(value->getType()) { + CHECK_NE(index_type_, nullptr); + } + + // Constructs an index of rank "size". Each dimension of the index is + // initialized to nullptr. + explicit Index(llvm::Type* index_ty, size_t size = 0) + : multidim_(size, nullptr), index_type_(index_ty) { + CHECK(index_ty->isIntegerTy()); + } // Constructs an index from multi-dimensional index "multidim". The linear // index is set to nullptr. - explicit Index(tensorflow::gtl::ArraySlice multidim) - : multidim_(multidim.begin(), multidim.end()) {} + explicit Index(tensorflow::gtl::ArraySlice multidim, + llvm::Type* index_ty = nullptr) + : multidim_(multidim.begin(), multidim.end()) { + if (size() == 0) { + index_type_ = index_ty; + } else { + index_type_ = (*this)[0]->getType(); + if (index_ty != nullptr) { + CHECK_EQ(index_type_, index_ty); + } + } + CHECK_NE(index_type_, nullptr); + CHECK(c_all_of(multidim, [&](llvm::Value* v) { + return index_type_ == v->getType(); + })); + } // Constructs an index from linear index "linear" and computes the // multi-dimensional index from "linear" and "shape". "ir_builder" is the IR @@ -154,6 +174,15 @@ class IrArray { llvm::Value* Linearize(tensorflow::gtl::ArraySlice dimensions, llvm::IRBuilder<>* builder) const; + llvm::Type* GetType() const { return index_type_; } + + llvm::Constant* GetConstantWithIndexType(int64 c) const { + // The LLVM function makes sure that the value can be represented by the + // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const + // APInt &V). + return llvm::ConstantInt::get(index_type_, c); + } + private: // Changing the multi-dimensional index invalidates the linear index. std::vector& multidim() { @@ -161,6 +190,9 @@ class IrArray { return multidim_; } + void Delinearize(std::vector* multidim, llvm::Value* linear, + const Shape& shape, llvm::IRBuilder<>* ir_builder) const; + std::vector multidim_; // These values are purely for efficiency; `multidim_` is enough to find the @@ -177,6 +209,8 @@ class IrArray { llvm::Value* linear_ = nullptr; Layout layout_; std::vector dims_; + + llvm::Type* index_type_; }; // Default constructor. Constructs an IrArray in a null status. diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 23d2d4e87d26f4988ebddcf20f5a27af6a7fe0d6..1f6e3c829f890d68aa251b101f0402c120a19d61 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -15,53 +15,57 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, - const std::function& for_body_generator) { - If(ir_builder_->CreateICmpSLT(start, end), [&]() { - for_body_generator(start, /*is_first_iteration=*/true); - For(name, ir_builder_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { for_body_generator(iv, false); }); + const std::function& for_body_generator) { + return If(ir_builder_->CreateICmpSLT(start, end), [&]() -> Status { + TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); + return For(name, ir_builder_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, - const std::function& for_body_generator) { + const std::function& + for_body_generator) { if (peel_first_iteration) { - For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) { - for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration)); - }); + return For(name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator( + indvar, ir_builder_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, ir_builder_, - /*prevent_unrolling=*/prevent_unrolling_, + /*unroll_mode=*/unroll_mode_, /*prevent_vectorization=*/prevent_vectorization_); ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); - for_body_generator(loop->GetIndVarValue(), - /*is_first_iteration=*/ir_builder_->CreateICmpEQ( - loop->GetIndVarValue(), start)); + TF_RETURN_IF_ERROR( + for_body_generator(loop->GetIndVarValue(), + /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + loop->GetIndVarValue(), start))); llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + return Status::OK(); } } -void KernelSupportLibrary::If( - llvm::Value* condition, const std::function& true_block_generator, - const std::function& false_block_generator) { +Status KernelSupportLibrary::If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, "", ir_builder_); ir_builder_->SetInsertPoint(&if_data.true_block->back()); - true_block_generator(); + TF_RETURN_IF_ERROR(true_block_generator()); ir_builder_->SetInsertPoint(&if_data.false_block->back()); - false_block_generator(); + TF_RETURN_IF_ERROR(false_block_generator()); llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); + return Status::OK(); } void KernelSupportLibrary::EmitAndCallOutlinedKernel( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 64b935bbf1fb9033cd2e1259b4639cd3780be711..6f7a9d94e3b9e59b2dfe12b9673335a904ae78b6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -30,13 +31,14 @@ namespace xla { class KernelSupportLibrary { public: // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. - // If `prevent_unrolling` is true then unrolling is explicitly disabled on - // every loop generated by this instance of KernelSupportLibrary. - explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = true, - bool prevent_vectorization = true) + // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop + // generated by this instance of KernelSupportLibrary. + explicit KernelSupportLibrary( + llvm::IRBuilder<>* ir_builder, + llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, + bool prevent_vectorization = true) : ir_builder_(ir_builder), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} // Generates the following control flow structure: @@ -46,19 +48,41 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - void For( + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator); + + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& - for_body_generator); + for_body_generator) { + CHECK_EQ(Status::OK(), + For(name, start, end, step, + [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& + for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } - void For( + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -75,46 +99,102 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& - for_body_generator); - - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - For(name, /*start=*/start, /*end=*/end, - /*step=*/ir_builder_->getInt64(step), peel_first_iteration, - for_body_generator); + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); + + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, llvm::Value* step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(For( + name, start, end, step, peel_first_iteration, + [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + return For(name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); } - void For( + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + ForReturnVoid(name, /*start=*/start, /*end=*/end, + /*step=*/llvm::ConstantInt::get(start->getType(), step), + peel_first_iteration, for_body_generator); + } + + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + return For(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); + } + + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + ForReturnVoid(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); + } + + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void For( + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - For(name, start, end, ir_builder_->getInt64(step), - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + ForReturnVoid(name, start, end, + llvm::ConstantInt::get(start->getType(), step), + for_body_generator); + } + + Status For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } - void For( + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -123,9 +203,25 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - void If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() {}); + Status If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = + []() -> Status { return Status::OK(); }); + + void IfReturnVoid(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() { + }) { + TF_CHECK_OK(If(condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); + } using ArgumentVector = tensorflow::gtl::ArraySlice; @@ -183,7 +279,7 @@ class KernelSupportLibrary { private: llvm::IRBuilder<>* ir_builder_; - bool prevent_unrolling_; + llvm_ir::UnrollMode unroll_mode_; bool prevent_vectorization_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 497b48ff227d7d1f158080529372df44b6932b24..c9ae7d3afd5cdc21157732f6d0dfa824268e86bd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -34,7 +34,7 @@ namespace llvm_ir { ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - llvm::Value* step, bool prevent_unrolling, + llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) : prefix_(std::string(prefix)), suffix_(std::string(suffix)), @@ -42,15 +42,15 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, end_index_(end_index), step_(step), insert_before_bb_(nullptr), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling, bool prevent_vectorization) { + UnrollMode unroll_mode, bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, - end_index, step, prevent_unrolling, + end_index, step, unroll_mode, prevent_vectorization)); loop->Emit(ir_builder); return loop; @@ -97,7 +97,7 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->SetInsertPoint(&func->getEntryBlock(), func->getEntryBlock().getFirstInsertionPt()); llvm::Value* indvar_address = - ir_builder->CreateAlloca(ir_builder->getInt64Ty(), nullptr, + ir_builder->CreateAlloca(start_index_->getType(), nullptr, AsStringRef(GetQualifiedName("invar_address"))); // Preheader basic block. @@ -147,11 +147,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { std::vector ForLoop::GetLoopMetadata( llvm::IRBuilder<>* ir_builder) { const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; + const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full"; const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; llvm::LLVMContext* ctx = &start_index_->getContext(); std::vector result; - if (prevent_unrolling_) { + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) { result.push_back(llvm::MDNode::get( *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); } @@ -162,6 +163,10 @@ std::vector ForLoop::GetLoopMetadata( llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); } + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)})); + } return result; } @@ -178,25 +183,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { - return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), - prevent_unrolling, prevent_vectorization); + return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1), + unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( - /*prefix=*/name_, suffix, start_index, end_index, stride, - prevent_unrolling, prevent_vectorization)); + /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode, + prevent_vectorization)); loop->Emit(ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { @@ -215,23 +220,23 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), prevent_unrolling, + return AddLoop(suffix, GetConstantWithIndexType(start_index), + GetConstantWithIndexType(end_index), unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); - return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), prevent_unrolling, + return AddLoop(suffix, GetConstantWithIndexType(start_index), + GetConstantWithIndexType(end_index), + GetConstantWithIndexType(stride), unroll_mode, prevent_vectorization); } @@ -245,7 +250,7 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( const Shape& shape, tensorflow::gtl::ArraySlice dimensions, tensorflow::StringPiece suffix) { - llvm_ir::IrArray::Index index(shape.dimensions_size(), nullptr); + llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr loop = AddLoop( /*start_index=*/0, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index d915f95db134918a173a9711936bb1e2f1ea0d95..0dd5b9d3b2656af68f76c2adfcb1f3a1385eeb91 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -34,6 +34,12 @@ limitations under the License. namespace xla { namespace llvm_ir { +enum class UnrollMode { + kDefaultUnroll, + kFullyUnroll, + kNoUnroll, +}; + // A class for constructing a for-loop in LLVM IR. class ForLoop { public: @@ -69,12 +75,13 @@ class ForLoop { // LLVM IR. If non-empty, it is prepended to the name of the induction // variable value and each basic block created for the loop. // - // If `prevent_unrolling` is true then emit metadata that directs LLVM to not - // unroll the generated loop. + // `unroll_mode` specifies the desired LLVM unrolling behavior for generated + // loop. static std::unique_ptr EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = false, bool prevent_vectorization = false); + UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // The names of the blocks follow LLVM's conventions. Control flow amongst the // blocks for the example C code looks like: @@ -128,7 +135,7 @@ class ForLoop { ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, - bool prevent_unrolling, bool prevent_vectorization); + UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* ir_builder); @@ -161,7 +168,7 @@ class ForLoop { llvm::BasicBlock* body_bb_; llvm::BasicBlock* exit_bb_; llvm::Value* indvar_; - bool prevent_unrolling_; + UnrollMode unroll_mode_; bool prevent_vectorization_; TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); @@ -170,46 +177,52 @@ class ForLoop { // A simple class for constructing nested for-loops. class ForLoopNest { public: - explicit ForLoopNest(llvm::IRBuilder<>* ir_builder) - : ForLoopNest(/*name=*/"", ir_builder) {} + explicit ForLoopNest(llvm::IRBuilder<>* ir_builder, + llvm::Type* index_ty = nullptr) + : ForLoopNest(/*name=*/"", ir_builder) { + SetIndexType(index_ty); + } - ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) + ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder, + llvm::Type* index_ty = nullptr) : name_(std::string(name)), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), - ir_builder_(ir_builder) {} + ir_builder_(ir_builder) { + SetIndexType(index_ty); + } // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have - // been added then emit loop inside the body of the last added loop. If - // prevent_unrolling is true, then metadata is emitting directing LLVM to not - // unroll this loop. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + // been added then emit loop inside the body of the last added loop. + // unroll_mode is used to emit metadata that controls LLVM unrolling. + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* stride, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. - std::unique_ptr AddLoop(int64 start_index, int64 end_index, - int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, int64 stride, + tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the @@ -245,6 +258,14 @@ class ForLoopNest { llvm::BasicBlock* GetInnerLoopBodyBasicBlock() { return inner_loop_body_bb_; } private: + void SetIndexType(llvm::Type* index_ty) { + index_type_ = index_ty == nullptr ? ir_builder_->getInt64Ty() : index_ty; + } + + llvm::Constant* GetConstantWithIndexType(int64 c) const { + return llvm::ConstantInt::get(index_type_, c); + } + // Human-friendly name of the loop nest. string name_; @@ -259,6 +280,8 @@ class ForLoopNest { llvm::IRBuilder<>* ir_builder_; + llvm::Type* index_type_; + TF_DISALLOW_COPY_AND_ASSIGN(ForLoopNest); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ff64da87e9c9acf8a9d7ff87d3b1be7a9e9106bb..e61a2fd12de71709dfb1b5a3b736c461d6072c1e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -193,6 +193,10 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // An Opaque is like a void*, use i8*. case OPAQUE: return llvm::Type::getInt8PtrTy(module->getContext()); + case TOKEN: + // Tokens do not have a physical representation, but the compiler needs + // some placeholder type, so use int8*. + return llvm::Type::getInt8PtrTy(module->getContext()); default: LOG(FATAL) << "unsupported type " << element_type; } @@ -245,167 +249,16 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, return shape; } -namespace { - -// Recursively construct a multidimensional LLVM constant which represents the -// given literal. The minor-to-major dimension ordering in the constant matches -// that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0 -// has size 4, dimension 1 has size 3, etc) of primitive type F32 with a -// minor_to_major value of [2, 1, 0] (column major), a LLVM constant of type -// [4 x [3 x [2 x float]] will be returned. -// -// multi_index is a multidimensional index into the array. dimension_index is an -// index into the minor_to_major field in the literal shape. This determines -// which dimension is iterated over in this level of the recursion. Dimensions -// are iterated from most major down to most minor (highest dimension_index -// value down to zero). -llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, - std::vector* multi_index, - llvm::Module* module) { - const Shape& shape = literal.shape(); - llvm::Type* ir_element_type = - llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module); - if (dimension_index == -1) { - // Base case of the recursion. Index into the data field of the protobuf - // with the multi index. - llvm::Constant* value; - switch (shape.element_type()) { - case PRED: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U8: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case S32: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U32: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case S64: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case U64: - value = llvm::ConstantInt::get(ir_element_type, - literal.Get(*multi_index)); - break; - case F32: - value = llvm::ConstantFP::get(ir_element_type, - literal.Get(*multi_index)); - break; - case BF16: - value = llvm::ConstantInt::get( - ir_element_type, - tensorflow::bit_cast(literal.Get(*multi_index))); - break; - case F16: - value = llvm::ConstantFP::get( - ir_element_type, - static_cast(literal.Get(*multi_index))); - break; - case F64: - value = llvm::ConstantFP::get(ir_element_type, - literal.Get(*multi_index)); - break; - case C64: { - complex64 x = literal.Get(*multi_index); - value = llvm::ConstantStruct::get( - static_cast(ir_element_type), - llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), - x.real()), - llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), - x.imag())); - break; - } - default: - LOG(FATAL) << "unsupported type " << shape.element_type(); - } - return value; - } - - // The dimension index starts at the one less than the rank of the array and - // decrements with each recursive call. We want to iterate through the - // dimensions in major-to-minor order as we recurse so just index into - // minor_to_major to get the dimension number for this level of the recursion. - int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index); - - // Recursively call LiteralToConstant to construct subarrays for the - // more-minor dimensions. Gather the subarrays into a vector for bundling into - // a new (higher-dimensional) ConstantArray. - std::vector elements; - for (int64 i = 0; i < shape.dimensions(dimension); ++i) { - (*multi_index)[dimension] = i; - elements.push_back( - LiteralToConstant(literal, dimension_index - 1, multi_index, module)); - } - - llvm::Type* element_type; - if (elements.empty()) { - element_type = ir_element_type; - for (int i = 0; i < dimension_index; ++i) { - int64 index = LayoutUtil::Minor(shape.layout(), i); - element_type = - llvm::ArrayType::get(element_type, shape.dimensions(index)); - } - } else { - element_type = elements[0]->getType(); - } - llvm::ArrayType* aggregate_type = - llvm::ArrayType::get(element_type, shape.dimensions(dimension)); - return llvm::ConstantArray::get(aggregate_type, elements); -} - -template -llvm::Constant* GetConstantDataArray(const Literal& literal, - llvm::Module* module) { - const T* data = static_cast(literal.untyped_data()); - int64 num_elements = literal.size_bytes() / sizeof(T); - return llvm::ConstantDataArray::get(module->getContext(), - llvm::makeArrayRef(data, num_elements)); -} - -} // namespace - llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { const Shape& shape = literal.shape(); - // TODO(b/29904935): We can get rid of this switch by exposing a - // ConstantDataArray factory method that takes a llvm::Type and a StringRef. - switch (shape.element_type()) { - case U64: - return GetConstantDataArray(literal, module); - case U32: - return GetConstantDataArray(literal, module); - case U8: - return GetConstantDataArray(literal, module); - case S64: - return GetConstantDataArray(literal, module); - case S32: - return GetConstantDataArray(literal, module); - case F64: - return GetConstantDataArray(literal, module); - case F32: - return GetConstantDataArray(literal, module); - case BF16: - case F16: - return GetConstantDataArray(literal, module); - case PRED: - return GetConstantDataArray(literal, module); - // TODO(b/29904935): Also use ConstantDataArray for complex numbers. - case C64: { - int64 dimensions = ShapeUtil::Rank(shape); - std::vector multi_index(dimensions, 0); - return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, - &multi_index, module); - } - default: - LOG(FATAL) << "unsupported type " << shape.element_type(); - } + llvm::Type* type = shape.element_type() == C64 + ? llvm::Type::getFloatTy(module->getContext()) + : PrimitiveTypeToIrType(shape.element_type(), module); + const char* data = static_cast(literal.untyped_data()); + uint64 num_elements = literal.size_bytes() * 8 / GetSizeInBits(type); + return llvm::ConstantDataArray::getRaw( + llvm::StringRef(data, literal.size_bytes()), num_elements, type); } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index dc2934a34c23f8229947210cacc9863d47c2ea55..e8b0605b9d75677b34f0973d88d269a5795b7629 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -90,11 +90,12 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name) { + tensorflow::StringPiece loop_name, llvm::Type* index_type) { + CHECK_NE(index_type, nullptr); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. exit_bb_ = nullptr; - return {IrArray::Index()}; + return {IrArray::Index(index_type)}; } // Create loop nest with one for-loop for each dimension of the target shape. @@ -102,7 +103,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( // class so emit loops in order from most-major dimension down to most-minor // dimension (of the target shape). ForLoopNest loop_nest(loop_name, ir_builder_); - IrArray::Index array_index(shape_.dimensions_size()); + IrArray::Index array_index(index_type, shape_.dimensions_size()); for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { int64 dimension = LayoutUtil::Major(shape_.layout(), i); std::unique_ptr loop = loop_nest.AddLoop( @@ -125,9 +126,14 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { +Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, + llvm::Type* index_type) { + if (index_type == nullptr) { + index_type = ir_builder_->getInt64Ty(); + } + for (const IrArray::Index& array_index : - EmitIndexAndSetExitBasicBlock(loop_name)) { + EmitIndexAndSetExitBasicBlock(loop_name, index_type)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index b70d28ecd3033eb26629718e50ce48f39b162273..6be1c2fba2cbd78a02865901ef8c5b7e2b2a74e6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -65,13 +65,16 @@ class LoopEmitter { // specifies the element, will return multiple indices if the loop is // unrolled. std::vector EmitIndexAndSetExitBasicBlock() { - return EmitIndexAndSetExitBasicBlock(/*loop_name=*/""); + return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", + ir_builder_->getInt64Ty()); } + virtual std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name); + tensorflow::StringPiece loop_name, llvm::Type* index_type); // Emits a complete loop nest for every element in the given shape. - Status EmitLoop(tensorflow::StringPiece loop_name = ""); + Status EmitLoop(tensorflow::StringPiece loop_name = "", + llvm::Type* index_type = nullptr); protected: // An IR emitter that generates the loop body. diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index dacc54742c0897bbd92315f1e33a484aae56bb7f..3b298f4746d6177da52ba0227705d07fbeba5c19 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -45,7 +45,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // Read start indices from start_indices_generator. const int64 rank = ShapeUtil::Rank(output_shape); - IrArray::Index start_index(rank); + IrArray::Index start_index(ir_builder->getInt64Ty(), rank); for (int64 i = 0; i < rank; ++i) { IrArray::Index dim_index({ir_builder->getInt64(i)}); TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); @@ -79,7 +79,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( // // output_index[dim] = start_index[dim] + update_index[dim] // - IrArray::Index output_index(rank); + IrArray::Index output_index(start_index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast( start_index[i], update_index[i]->getType()); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 1d9c9e0678765a779ec94e578e0e6f69d46b80de..53efc30c3653879709fceae3dcdd4f679740f622 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -155,7 +154,8 @@ StatusOr> LocalService::CompileExecutable( for (int i = 0; i < argument_layouts.size(); ++i) { const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { tensorflow::gtl::optional metadata = ParameterMetadata(computation, /*parameter_number=*/i); @@ -179,8 +179,8 @@ StatusOr> LocalService::CompileExecutable( } } if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape.result())); + TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(), + program_shape.result())); } ExecutionOptions execution_options = @@ -190,6 +190,9 @@ StatusOr> LocalService::CompileExecutable( std::unique_ptr module_config, CreateModuleConfig(program_shape, argument_layouts, &execution_options)); + VLOG(3) << "Computation Layout: " + << module_config->entry_computation_layout().ToString(); + TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, execute_backend_->stream_executor(build_options.device_ordinal())); diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..79b5a442aa0ecd0f67ffe4dad50465627d8913fd --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -0,0 +1,359 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +StatusOr MultiOutputFusion::Run(HloModule* module) { + bool changed = false; + + for (auto* computation : module->MakeNonfusionComputations()) { + computation_ = computation; + RecomputeReachability(); + candidates_.clear(); + candidates_index_.clear(); + all_fusion_candidates_.clear(); + + int64 index = 0; + for (auto it : computation_->MakeInstructionPostOrder()) { + candidates_.emplace_back(it); + InsertOrDie(&candidates_index_, it, index++); + } + + // Create the initial candidate list for each Node. + for (auto& node : candidates_) { + HloInstruction* instruction = node.hlo; + int64 instruction_id = get_candidate_id(instruction); + FusionCandidate& instr_node = candidates_[instruction_id]; + if (!IsFusible(instruction)) { + continue; + } + all_fusion_candidates_.push_back(instruction); + + std::vector candidates; + tensorflow::gtl::FlatSet candidates_set; + VLOG(10) << "Looking at instruction: " << instruction->name(); + for (auto operand : instruction->operands()) { + // Filter out the non-interesting instructions -- they + // will not generate the savings. + if (!IsProfitableOperand(operand)) { + VLOG(10) << "Operand not profitable: " << operand->name(); + continue; + } + VLOG(10) << "Operand profitable: " << operand->name(); + for (auto user : operand->users()) { + VLOG(10) << "User: " << user->name(); + if (user == instruction || !IsFusible(user)) { + VLOG(10) << "User is not fusible, or is the instruction itself: " + << user->name(); + continue; + } + int64 user_id = get_candidate_id(user); + if (is_connected(instruction, user)) { + VLOG(10) << "User is connected: " << user->name(); + continue; + } + if (instruction_id < user_id && + user->opcode() == HloOpcode::kFusion) { + VLOG(10) << "User ID for user: " << user->name() << " is " + << user_id << " which is higher than " << instruction_id; + continue; + } + if (!LegalToFuse(instruction, user)) { + VLOG(10) << "User not legal to fuse: " << user->name(); + continue; + } + if (candidates_set.insert(user).second) { + VLOG(10) << "User added to candidate list: " << user->name(); + candidates.push_back(user); + } + } + } + + // Iterate over candidates rather than candidates_set to avoid + // nondeterminism. + for (auto candidate : candidates) { + int64 profit = GetProfit(instruction, candidate); + if (profit > 0) { + FusionCandidate& candidate_node = + candidates_[get_candidate_id(candidate)]; + instr_node.fusibles.emplace_back(candidate, profit); + candidate_node.fusibles.emplace_back(instruction, profit); + worklist_.emplace(instruction, candidate, profit); + } + } + } + if (Perform()) { + changed = true; + } + } + return changed; +} + +HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, + HloInstruction* instr2) { + HloInstruction* remaining = instr1; + HloInstruction* fused = instr2; + // Make sure that if only one of the instructions is a fusion, or if only one + // of the instructions is a multi-output fusion, it's what will be fused into. + // + // An invariant is that no bitcast nodes will show up in the middle of a + // fusion node. This invariant must hold in order for us to lower it. Given + // that, we require that during multi-output fusion, a fusion node ending with + // bitcast to preserve its structure as a nested fusion instead being + // merged and flattened. + if (fused->opcode() == HloOpcode::kFusion && + fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + std::swap(remaining, fused); + } + if (fused->IsMultiOutputFusion()) { + std::swap(remaining, fused); + } + + if (fused->opcode() == HloOpcode::kFusion && + fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + remaining->MergeFusionInstructionIntoMultiOutput(fused); + } else { + if (remaining->opcode() == HloOpcode::kFusion && + remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) { + auto parent_computation = remaining->parent(); + // Create a nested fusion node. + auto remaining_nested_fused = + parent_computation->AddInstruction(HloInstruction::CreateFusion( + remaining->shape(), HloInstruction::FusionKind::kLoop, + remaining)); + TF_CHECK_OK(parent_computation->ReplaceInstruction( + remaining, remaining_nested_fused)); + remaining = remaining_nested_fused; + } + remaining->FuseInstructionIntoMultiOutput(fused); + } + + return remaining; +} + +bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) { + // kConstant instruction will not have memory reads, so it won't be a profit + // source. Skip them. + if (instr->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(instr->shape())) { + return false; + } + // We don't target to fuse producer/consumer instructions -- this should + // be taken care of by the instruction_fusion pass. If instr has only + // one user, it will not have sibling instructions. We won't consider it. + if (instr->user_count() < 2) { + return false; + } + return true; +} + +void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { + HloInstruction* fusion = instr1; + HloInstruction* fused = instr2; + if (is_fused(instr1)) { + fusion = instr2; + fused = instr1; + } + + // Insert the newly created instruction (if any), to candidates_. + for (auto use : fusion->users()) { + if (candidates_index_.find(use) == candidates_index_.end()) { + int64 index = candidates_.size(); + candidates_.emplace_back(use); + InsertOrDie(&candidates_index_, use, index++); + } + } + FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)]; + FusionCandidate& fused_node = candidates_[get_candidate_id(fused)]; + + // Update the reachability graph. + UpdateReachability(fusion, fused, all_fusion_candidates_, + [this](HloInstruction* instr) { return is_fused(instr); }); + + // Update the fusible list for fusion. Variable new_fusibles keeps + // track of the new or changed entries. + std::vector> new_fusibles; + tensorflow::gtl::FlatSet in_list; + auto it = fusion_node.fusibles.begin(); + while (it != fusion_node.fusibles.end()) { + HloInstruction* instr = it->first; + if (is_fused(instr) || is_connected(fusion, instr)) { + it = fusion_node.fusibles.erase(it); + continue; + } + in_list.insert(instr); + int64 profit = GetProfit(instr, fusion); + if (profit > it->second) { + it->second = profit; + new_fusibles.emplace_back(instr, profit); + } + ++it; + } + + // Fused_node has been fused into fusion_node. Take the fusion candidates + // (fusibles) from fused_nodes and add them to the fusion_node's. Filter + // out those fusibles that no longer valid (or already in the list). + for (const auto& it : fused_node.fusibles) { + HloInstruction* instr = it.first; + if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) { + continue; + } + if (in_list.count(instr) > 0) { + continue; + } + int64 profit = GetProfit(instr, fusion); + fusion_node.fusibles.emplace_back(instr, profit); + new_fusibles.emplace_back(instr, profit); + } + fused_node.fusibles.clear(); + + // Update the worklist_. + for (auto it : new_fusibles) { + worklist_.emplace(fusion, it.first, it.second); + } +} + +bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, + HloInstruction* instr2) { + if (instr1 == instr2) { + return false; + } + if (instr1->opcode() != HloOpcode::kFusion) { + return false; + } + + // Fusing nodes with 0 user makes no sense and the rest of the implementation + // doesn't support it either. + if (instr1->user_count() == 0 || instr2->user_count() == 0) { + return false; + } + + // Check if the users of multioutput fusion is not a get-tuple-element. + // If this is the case, we bail out because the transformation assumes + // the users are get-tuple-element. + auto multioutput_user_is_not_gte = [](HloInstruction* instr) { + if (!instr->IsMultiOutputFusion()) { + return false; + } + for (auto user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return true; + } + } + return false; + }; + if (multioutput_user_is_not_gte(instr1) || + multioutput_user_is_not_gte(instr2)) { + return false; + } + + if (is_connected(instr1, instr2)) { + return false; + } + if (!ShapesCompatibleForFusion(instr1, instr2)) { + return false; + } + + return true; +} + +void MultiOutputFusion::RecomputeReachability() { + reachability_ = computation_->ComputeReachability(); +} + +void MultiOutputFusion::UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip) { + for (auto instr : instrs_to_update) { + if (skip != nullptr && skip(instr)) { + continue; + } + if (reachability_->IsReachable(instr2, instr) && + reachability_->IsReachable(instr1, instr)) { + // If a candidate was already reachable by both, no update needed. + continue; + } + if (reachability_->IsReachable(instr2, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr1}, instr); + } + if (reachability_->IsReachable(instr1, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr2}, instr); + } + } +} + +bool MultiOutputFusion::Perform() { + int changed = false; + // Pick the top candidate from queue and try to merge. + while (!worklist_.empty()) { + if (fuel_ <= 0) { + VLOG(2) << "No fusing: run out of fuel."; + break; + } + ToBeFused candidate = worklist_.top(); + worklist_.pop(); + + HloInstruction* instr1 = candidate.instr1; + HloInstruction* instr2 = candidate.instr2; + + if (is_fused(instr1) || is_fused(instr2)) { + continue; + } + + VLOG(1) << "Considering candidate profit_score=" << candidate.score + << "\n\t\tinstr1 = " << instr1->ToString() + << "\n\t\tinstr2 = " << instr2->ToString(); + + if (LegalToFuse(instr1, instr2)) { + VLOG(1) << "Fuse!"; + VLOG(2) << "Before multi_output_fusion:"; + VLOG(2) << "instr1: " << instr1->ToString(); + VLOG(2) << "\n" + << instr1->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + VLOG(2) << "instr2: " << instr2->ToString(); + if (instr2->opcode() == HloOpcode::kFusion) { + VLOG(2) << "\n" + << instr2->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + } + HloInstruction* ret = Fuse(instr1, instr2); + set_is_fused(ret == instr1 ? instr2 : instr1); + Update(instr1, instr2); + changed = true; + VLOG(2) << "After fusion, \t this: " << ret->name() << "\n" + << ret->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + auto users = ret->users(); + --fuel_; + } + } + if (DoProducerConsumerMultiOutputFusion()) { + changed = true; + } + return changed; +} + +bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion() { return false; } +} // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..d23822e33e11ede0c5cac97e9fe2b0c3dc88cf3d --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -0,0 +1,169 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// This class implements the fusing of sibling fusion instructions that sharing +// common operands. +// It constructs the following associated data structures. +// (1) candidates_: stores the instruction and the set of instructions it can +// fuse to. +// (2) candidates_index_: maps instruction to id. +// (3) reachability_: reachability map in this computation. +// (4) all_fusion_candidates_: the vector of candidate instructions. +// (5) worklist_: a priority queue that contains pairs of instructions to be +// fused and their fusion profit scores. +// +// Function Perform() applies the optimization. It picks up the most profitable +// pair in the worklist_, check if it's legal to fuse and fuse the pair. +// After fusion, it updates the associated structure such as reachability_, +// candidates_ and worklist_. +// Note that the reachability map is updated based on the original computation. +// This works because the reachability is monotonically increasing with +// instruction fusion. +class MultiOutputFusion : public HloPassInterface { + public: + MultiOutputFusion(int64 fuel) : fuel_(fuel) {} + + tensorflow::StringPiece name() const override { + return "multi_output_fusion"; + } + + // Run multi-output fusion on the given module. Returns whether the module + // was changed. + StatusOr Run(HloModule* module) override; + + protected: + // Main entry for the optimization. Returns true if the optimization happens. + bool Perform(); + + // Test if instr1 and instr2 have the compatible shapes that can be legally + // fused. + virtual bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) = 0; + + // Whether the instruction is a candidate for fusion. + virtual bool IsFusible(HloInstruction* instr) = 0; + + // This function estimates the savings by merging instr1 and instr2 into one + // multi-output fusion instruction. + virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0; + + // Whether fusing the instruction can reduce memory reads. + virtual bool IsProfitableOperand(HloInstruction* instr); + + // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. + virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + + // Recompute reachability for the current computation. + void RecomputeReachability(); + + // Returns the reachability map for the current computation. + HloReachabilityMap* reachability() const { return reachability_.get(); } + + // Returns the computation for the pass. + HloComputation* computation() const { return computation_; } + + // Update the reachability map after fusing instr1 and instr2. + void UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip = nullptr); + + // Hook for multi-output fusion along producer-consumer edges. + // Returns whether any instructions were fused. + // + // TODO(b/80420762): Perform producer-consumer multi-output fusion in + // InstructionFusion instead. + virtual bool DoProducerConsumerMultiOutputFusion(); + + private: + // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. + // The other instruction is removed from its parent computation. + HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); + + // Optimization fuel is a compiler debugging technique that makes an + // optimization pass stop what it is doing after having made N changes to the + // program, where N is the fuel. By varying N, this can be used to find the + // first single change that makes a test fail. + int64 fuel_; + + // Computation for the pass. + HloComputation* computation_; + + // An internal data structure for each instruction in current computation. + // When an instruction is removed, member 'hlo' is set to nullptr. + struct FusionCandidate { + HloInstruction* hlo; + std::list> fusibles; + explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {} + }; + std::vector candidates_; + + // A map that maps an instruction to the index_. + tensorflow::gtl::FlatMap candidates_index_; + + // The reachability map of current computation. + std::unique_ptr reachability_; + + // This stores all the candidate instructions in current computation. + std::vector all_fusion_candidates_; + + // The pair of candidates to be fused and the profit score. + struct ToBeFused { + HloInstruction* instr1; + HloInstruction* instr2; + int64 score; + ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score) + : instr1(instr1), instr2(instr2), score(score) {} + bool operator<(const ToBeFused& rhs) const { return score < rhs.score; } + }; + std::priority_queue worklist_; + + int64 get_candidate_id(HloInstruction* instr) { + return FindOrDie(candidates_index_, instr); + } + + bool is_fused(HloInstruction* instr) { + return candidates_[get_candidate_id(instr)].hlo == nullptr; + } + + void set_is_fused(HloInstruction* instr) { + candidates_[get_candidate_id(instr)].hlo = nullptr; + } + + bool is_connected(HloInstruction* instr1, HloInstruction* instr2) { + return reachability_->IsConnected(instr1, instr2); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 0f26a025bf125f70199637894741540f89eae7e5..49ec38eb62c7b51c7a2d301d882cef032b288036 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -155,20 +155,15 @@ HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, case HloOpcode::kConstant: { if (first_reshape_operand->opcode() == HloOpcode::kReshape) { VLOG(5) << "Adding reshape to kConstant operand"; - HloInstruction* reshape = computation->AddInstruction( + return computation->AddInstruction( HloInstruction::CreateReshape(new_shape, operand)); - operand->SetupDerivedInstruction(reshape); - return reshape; } else { CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose); VLOG(5) << "Adding transpose to kConstant operand"; std::vector inverse_permutation = InversePermutation(first_reshape_operand->dimensions()); - HloInstruction* transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - new_shape, operand, inverse_permutation)); - operand->SetupDerivedInstruction(transpose); - return transpose; + return computation->AddInstruction(HloInstruction::CreateTranspose( + new_shape, operand, inverse_permutation)); } } case HloOpcode::kRng: { diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 82be6bcf4f05f01249028a47ede4b3fc9fe31722..da3b622bfae8ac5132f9f95070ee41674e79b5b8 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" @@ -62,55 +61,28 @@ namespace xla { namespace { -// Records the arguments used to invoke a computation in a SessionModule -// proto. -Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - se::StreamExecutor* executor, TransferManager* transfer_manager, - SessionModule* module) { - module->clear_arguments(); - for (const ShapedBuffer* argument : arguments) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, *argument)); - *module->add_arguments() = literal->ToProto(); - } - return Status::OK(); -} - -// Records the result of a computation in a SessionModule proto. -Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, - TransferManager* transfer_manager, SessionModule* module) { - module->clear_result(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, result)); - *module->mutable_result() = literal->ToProto(); - return Status::OK(); -} - // Records the arguments used to invoke a computation in an HloSnapshot proto. Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, - se::StreamExecutor* executor, TransferManager* transfer_manager, + se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN( std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, *argument)); + transfer_manager->TransferLiteralFromDevice(stream, *argument)); *module->add_arguments() = literal->ToProto(); } return Status::OK(); } // Records the result of a computation in a HloSnapshot proto. -Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, +Status RecordResult(const ShapedBuffer& result, se::Stream* stream, TransferManager* transfer_manager, HloSnapshot* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, result)); + transfer_manager->TransferLiteralFromDevice(stream, result)); *module->mutable_result() = literal->ToProto(); return Status::OK(); } @@ -219,21 +191,17 @@ Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, return Status::OK(); } -Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout, - const Shape& result_shape) const { - if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { +Status Service::ValidateResultShape(const Shape& client_shape, + const Shape& result_shape) const { + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape)); + if (!ShapeUtil::Compatible(client_shape, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " "with result shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), + ShapeUtil::HumanStringWithLayout(client_shape).c_str(), ShapeUtil::HumanString(result_shape).c_str()); } - if (!LayoutUtil::HasLayout(shape_with_layout)) { - return InvalidArgument( - "Shape used to set computation result layout %s does not have layout", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); - } - return ShapeUtil::ValidateShape(shape_with_layout); + return Status::OK(); } StatusOr>> @@ -276,10 +244,8 @@ StatusOr> Service::CreateModuleConfig( tensorflow::gtl::ArraySlice argument_shapes, const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); - ComputationLayout* host_computation_layout = - config->mutable_host_entry_computation_layout(); - ComputationLayout* device_computation_layout = - config->mutable_device_entry_computation_layout(); + ComputationLayout* computation_layout = + config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { return InvalidArgument("computation takes %d parameters, but %zu given", program_shape.parameters_size(), @@ -296,32 +262,22 @@ StatusOr> Service::CreateModuleConfig( i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } - TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i) - ->CopyLayoutFromShape(*argument_shapes[i])); - TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i) - ->CopyLayoutFromShape(*argument_shapes[i])); + TF_RETURN_IF_ERROR( + computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + *argument_shapes[i])); } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { const auto& shape_with_output_layout = execution_options->shape_with_output_layout(); - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout, - program_shape.result())); TF_RETURN_IF_ERROR( - host_computation_layout->mutable_result_layout()->CopyLayoutFromShape( - shape_with_output_layout)); + ValidateResultShape(shape_with_output_layout, program_shape.result())); TF_RETURN_IF_ERROR( - device_computation_layout->mutable_result_layout()->CopyLayoutFromShape( + computation_layout->mutable_result_layout()->CopyLayoutFromShape( shape_with_output_layout)); } else { // If the result layout is not set, then choose the default. - // TODO(b/29118294): Allow the compiler to choose a better layout in this - // case. - // TODO(b/78356948): We are forcing the default layout here. We should fix - // clients which expect a default layout, to be explicit about it, by - // passing the proper ExecutionOptions with shape_with_output_layout set. - host_computation_layout->mutable_result_layout()->SetToDefaultLayout(); - device_computation_layout->mutable_result_layout()->SetToDefaultLayout(); + computation_layout->mutable_result_layout()->SetToDefaultLayout(); } config->set_replica_count(options_.number_of_replicas()); @@ -376,8 +332,8 @@ StatusOr>> Service::BuildExecutables( module_protos[i]->entry_computation_name().c_str()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); - hlo_snapshots.push_back(std::move(hlo_snapshot)); } + hlo_snapshots.push_back(std::move(hlo_snapshot)); } VLOG(1) << "Computations:"; @@ -409,22 +365,6 @@ StatusOr>> Service::BuildExecutables( return std::move(executables); } -Status Service::ValidateEntryComputationLayout(HloModule* module) { - const ComputationLayout& on_device = - module->device_entry_computation_layout(); - for (int64 i = 0; i < on_device.parameter_count(); ++i) { - TF_RET_CHECK(ShapeUtil::Equal( - on_device.parameter_shape(i), - execute_backend_->transfer_manager()->HostShapeToDeviceShape( - module->host_entry_computation_layout().parameter_shape(i)))); - } - TF_RET_CHECK(ShapeUtil::Equal( - module->device_entry_computation_layout().result_shape(), - execute_backend_->transfer_manager()->HostShapeToDeviceShape( - module->host_entry_computation_layout().result_shape()))); - return Status::OK(); -} - StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, @@ -526,7 +466,7 @@ Service::ExecuteParallelAndRegisterResult( HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); TF_RETURN_IF_ERROR( - executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); + executable->PopulateExecutionProfile(&hlo_profile, stream)); XLA_LOG_LINES( tensorflow::INFO, hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); @@ -720,7 +660,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); + << module_config->entry_computation_layout().ToString(); // Adds to the vectors to build and execute the computations after the loop. all_arguments.push_back(replicated_arguments); @@ -749,6 +689,17 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, executable_ptrs.push_back(executable.get()); } + for (int i = 0; i < executable_ptrs.size(); i++) { + if (executable_ptrs[i]->dumping_snapshot()) { + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream( + all_executors[i][0]->device_ordinal())); + TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(), + execute_backend_->transfer_manager(), + executable_ptrs[i]->hlo_snapshot())); + } + } + // Execute the generated executables in parallel and return the device // handles for each computation's output. ExecutionProfile profile; @@ -764,6 +715,20 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, *result->add_responses() = response; } + for (int i = 0; i < executable_ptrs.size(); i++) { + if (executable_ptrs[i]->dumping_snapshot()) { + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(outputs[i], 0)); + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream(all_executors[i][0])); + TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), + execute_backend_->transfer_manager(), + executable_ptrs[i]->hlo_snapshot())); + // Dump out the ith snapshot. + TF_RETURN_IF_ERROR(executable_ptrs[i]->DumpHloSnapshot()); + } + } + VLOG(1) << "successfully completed 'execute-graph-parallel' request"; return Status::OK(); } @@ -856,13 +821,15 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, device_allocator)); - // Check that on-host and on-device shapes are consistent. - TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, backend->compiler()->RunBackend( std::move(module), executor, device_allocator)); + if (!execution_directory_path.empty()) { + executable->set_hlo_snapshot(std::move(hlo_snapshot)); + } + return std::move(executable); } @@ -900,12 +867,14 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, execute_backend_->default_stream_executor(), /*device_allocator=*/nullptr)); + TF_ASSIGN_OR_RETURN(auto stream, + execute_backend_->BorrowStream( + execute_backend_->default_stream_executor())); if (executable->dumping_snapshot()) { executable->hlo_snapshot()->set_execution_platform( execute_backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), - execute_backend_->default_stream_executor(), + replicated_arguments.front(), stream.get(), execute_backend_->transfer_manager(), executable->hlo_snapshot())); } @@ -919,9 +888,9 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, TF_ASSIGN_OR_RETURN( const ShapedBuffer* result_buffer, allocation_tracker_.ResolveForReplica(result->output(), 0)); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->hlo_snapshot())); + TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), + execute_backend_->transfer_manager(), + executable->hlo_snapshot())); TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); } @@ -959,14 +928,13 @@ Status Service::TransferToClient(const TransferToClientRequest* arg, return_shape = &shaped_buffer->on_host_shape(); } - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - execute_backend_->stream_executor(shaped_buffer->device_ordinal())); + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( + shaped_buffer->device_ordinal())); TF_ASSIGN_OR_RETURN( std::unique_ptr result_literal, execute_backend_->transfer_manager()->TransferLiteralFromDevice( - executor, *shaped_buffer)); + stream.get(), *shaped_buffer)); if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal->shape())) { @@ -1016,9 +984,10 @@ Status Service::TransferToServer(const TransferToServerRequest* arg, execute_backend_->transfer_manager()->AllocateScopedShapedBuffer( shape, execute_backend_->memory_allocator(), executor->device_ordinal())); + TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor)); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, shaped_buffer)); + stream.get(), *literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 422bb95657e339dd20b64f946504827fb8b7ca41..47d196fb2aaee897ce1fd3745129af10bf5b2d2d 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -26,15 +26,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" -#include "tensorflow/compiler/xla/service/compilation_cache.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -196,9 +193,6 @@ class Service : public ServiceInterface { const ExecutionOptions& execution_options, tensorflow::gtl::ArraySlice arguments); - // Assert that host- and device-shapes are in a consistent state. - Status ValidateEntryComputationLayout(HloModule* module); - protected: friend class LocalExecutable; @@ -269,11 +263,11 @@ class Service : public ServiceInterface { // will be the result of this computation. Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); - // Convenience function which checks whether the given shape_with_layout + // Convenience function which checks whether the given client_shape // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. - Status ValidateResultShapeWithLayout(const Shape& shape_with_layout, - const Shape& result_shape) const; + Status ValidateResultShape(const Shape& client_shape, + const Shape& result_shape) const; // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that @@ -298,9 +292,6 @@ class Service : public ServiceInterface { // Tracks asynchronously launched executions via the API. ExecutionTracker execution_tracker_; - // Cache containing previously built Executables. - CompilationCache compilation_cache_; - // Backend to compile and execute computations on. std::unique_ptr execute_backend_; diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto deleted file mode 100644 index bb8d1cd2a106ea3e5bb61eee5052bd60c38cd0e2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/session.proto +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This proto file defines messages which store the state of XLA -// computations within the XLA service. A computation is stored as a record -// of the operation requests used to build it. -syntax = "proto3"; - -import "tensorflow/compiler/xla/xla_data.proto"; - -package xla; - -// Describes a single operation request. -message OperationRequest { - ComputationDataHandle output_handle = 1; - Shape output_shape = 2; - - // For operations which call embedded computations such as "Map", these are - // the version(s) that the embedded computation should be called at. A version - // value of a computation is the ComputationDataHandle of the root of the - // computation at the point in time. - // - // "Call", "Map", "Reduce", and "ReduceWindow" operations take a single - // embedded computation so this field will have a single value for those - // operations. - // - // "While" operation takes two; index 0 is the "condition" version and index 1 - // is the "body" version. - repeated int64 embedded_computation_versions = 3; - - // The actual request, which in itself is a tagged union of all possible - // operation request types. - OpRequest request = 4; -} - -// Describes a sequence of operation requests which define an XLA -// computation. -message SessionComputation { - string name = 1; - - // The ComputationHandle used to refer to this computation in the XLA - // service. - ComputationHandle computation_handle = 2; - - // Map from ComputationDataHandle value to operation request. The highest - // ComputationDataHandle value corresponds to the root of the computation. - map requests = 3; -} - -// Describes a group of SessionComputations with an "entry point" computation -// that may refer to the other non-entry (AKA embedded) computations. -// -// This message is used to serialize a computation that has been built via the -// XLA service API, along with its dependencies, for purposes such as -// analysis/replay/file-storage. -message SessionModule { - // The entry computation, which was requested for serialization. This may have - // referred to embedded computations, which are reflected below. - SessionComputation entry = 1; - - // Embedded computations that are transitively referred to by the entry - // computation. - repeated SessionComputation embedded_computations = 2; - - // The arguments passed to the computation. - repeated LiteralProto arguments = 3; - - // The result of the computation. - LiteralProto result = 4; - - // The name of the platform used to run the computation. - string execution_platform = 5; -} diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index d624f548b1ba65e6f6dfd7b329e8c86ab29112a0..4606d8f202f8a63cbc6efae8fa096fdff2f1a010 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -44,147 +44,18 @@ namespace xla { namespace { -// Return the UnaryOperation proto enum value associated with the given HLO -// opcode. -UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAbs: - return UNOP_ABS; - case HloOpcode::kCeil: - return UNOP_CEIL; - case HloOpcode::kClz: - return UNOP_CLZ; - case HloOpcode::kCos: - return UNOP_COS; - case HloOpcode::kExp: - return UNOP_EXP; - case HloOpcode::kExpm1: - return UNOP_EXPM1; - case HloOpcode::kFloor: - return UNOP_FLOOR; - case HloOpcode::kImag: - return UNOP_IMAG; - case HloOpcode::kIsFinite: - return UNOP_IS_FINITE; - case HloOpcode::kLog: - return UNOP_LOG; - case HloOpcode::kLog1p: - return UNOP_LOG1P; - case HloOpcode::kNot: - return UNOP_NOT; - case HloOpcode::kNegate: - return UNOP_NEGATE; - case HloOpcode::kReal: - return UNOP_REAL; - case HloOpcode::kRoundNearestAfz: - return UNOP_ROUND_NEAREST_AFZ; - case HloOpcode::kSign: - return UNOP_SIGN; - case HloOpcode::kSin: - return UNOP_SIN; - case HloOpcode::kSort: - return UNOP_SORT; - case HloOpcode::kTanh: - return UNOP_TANH; - default: - LOG(FATAL) << "Unhandled opcode for conversion to unary operation: " - << opcode; - } -} - -// Return the BinaryOperation proto enum value associated with the given HLO -// opcode. -BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAtan2: - return BINOP_ATAN2; - case HloOpcode::kComplex: - return BINOP_COMPLEX; - case HloOpcode::kMultiply: - return BINOP_MUL; - case HloOpcode::kAdd: - return BINOP_ADD; - case HloOpcode::kSubtract: - return BINOP_SUB; - case HloOpcode::kDivide: - return BINOP_DIV; - case HloOpcode::kEq: - return BINOP_EQ; - case HloOpcode::kGe: - return BINOP_GE; - case HloOpcode::kGt: - return BINOP_GT; - case HloOpcode::kLe: - return BINOP_LE; - case HloOpcode::kLt: - return BINOP_LT; - case HloOpcode::kNe: - return BINOP_NE; - case HloOpcode::kMaximum: - return BINOP_MAX; - case HloOpcode::kMinimum: - return BINOP_MIN; - case HloOpcode::kPower: - return BINOP_POW; - case HloOpcode::kRemainder: - return BINOP_REM; - case HloOpcode::kOr: - return BINOP_OR; - case HloOpcode::kAnd: - return BINOP_AND; - case HloOpcode::kShiftLeft: - return BINOP_SHIFT_LEFT; - case HloOpcode::kShiftRightArithmetic: - return BINOP_SHIFT_RIGHT_ARITHMETIC; - case HloOpcode::kShiftRightLogical: - return BINOP_SHIFT_RIGHT_LOGICAL; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the TernaryOperation proto enum value associated with the given HLO -// opcode. -TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kClamp: - return TRIOP_CLAMP; - case HloOpcode::kSelect: - return TRIOP_SELECT; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the VariadicOperation proto enum value associated with the given HLO -// opcode. -VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kTuple: - return VAROP_TUPLE; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - // Returns true if no element is present in slice more than once. bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -Status ExpectNotTupleOrOpaque(const Shape& shape, - tensorflow::StringPiece op_type) { - if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Expected non-tuple argument for %s, but got %s.", - std::string(op_type).c_str(), - ShapeUtil::HumanString(shape).c_str()); - } else if (ShapeUtil::IsOpaque(shape)) { - return InvalidArgument("Expected non-opaque argument for %s, but got %s.", +Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument("Expected array argument for %s, but got %s.", std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); - } else { - return Status::OK(); } + return Status::OK(); } Status VerifyReducerShape(const ProgramShape& reducer_shape, @@ -321,84 +192,80 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return shape; } - return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape); -} + TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation")); -/* static */ StatusOr ShapeInference::InferUnaryOpShape( - UnaryOperation operation, const Shape& arg) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation")); - - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg)); - switch (operation) { - case UNOP_FLOOR: - case UNOP_CEIL: - if (!ShapeUtil::ElementIsFloating(arg)) { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + switch (opcode) { + case HloOpcode::kFloor: + case HloOpcode::kCeil: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( "Expected element type in shape to be floating for floor/ceil " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_COS: - case UNOP_SIN: - case UNOP_EXP: - case UNOP_EXPM1: - case UNOP_LOG: - case UNOP_LOG1P: - case UNOP_TANH: - if (!ShapeUtil::ElementIsFloating(arg) && - !ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kCos: + case HloOpcode::kSin: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kTanh: + if (!ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( "Expected element type in shape to be floating or complex for " "sin/cos/exp/log/tanh operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_REAL: - case UNOP_IMAG: - if (!ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kReal: + case HloOpcode::kImag: + if (!ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( "Expected element type in shape to be complex for real/imag " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, F32); - case UNOP_ABS: - if (ShapeUtil::ElementIsComplex(arg)) { + return ShapeUtil::ChangeElementType(shape, F32); + case HloOpcode::kAbs: + if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( - arg, primitive_util::ComplexComponentType(arg.element_type())); + shape, primitive_util::ComplexComponentType(shape.element_type())); } - return arg; - case UNOP_CLZ: - case UNOP_NEGATE: - case UNOP_ROUND_NEAREST_AFZ: - case UNOP_SIGN: - case UNOP_SORT: - return arg; - - case UNOP_NOT: - if (arg.element_type() != PRED && - !primitive_util::IsIntegralType(arg.element_type())) { + return shape; + case HloOpcode::kClz: + case HloOpcode::kNegate: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSign: + case HloOpcode::kSort: + return shape; + + case HloOpcode::kNot: + if (shape.element_type() != PRED && + !primitive_util::IsIntegralType(shape.element_type())) { return InvalidArgument( "Expected pred or an integral element type in argument to Not " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; + return shape; - case UNOP_IS_FINITE: - if (!ShapeUtil::ElementIsFloating(arg)) { + case HloOpcode::kIsFinite: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "Expected element type in shape to be floating point for IsFinite " + "Expected element type in shape to be floating " + "point for IsFinite " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, PRED); + return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - UnaryOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -415,8 +282,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const Shape* arg_shape = nullptr; PrimitiveType element_type = PRIMITIVE_TYPE_INVALID; for (const Shape* shape : arg_shapes) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*shape, "operand of concatenation")); + TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation")); if (!arg_shape) { arg_shape = shape; element_type = arg_shape->element_type(); @@ -463,6 +329,17 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } +/* static */ StatusOr ShapeInference::InferGenerateTokenShape( + tensorflow::gtl::ArraySlice arg_shapes) { + for (const Shape* arg_shape : arg_shapes) { + if (arg_shape->element_type() != TOKEN) { + return InvalidArgument( + "Operands of token instructions must be TOKEN types."); + } + } + return ShapeUtil::MakeTokenShape(); +} + /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); @@ -473,12 +350,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } - if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + if (!ShapeUtil::IsArray(operand_shape) || + !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. return InvalidArgument( - "Convert does not allow tuples, so cannot convert from %s to %s.", + "Convert does not allow non-arrays, so cannot convert from %s to %s.", ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } @@ -495,7 +373,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } - if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + if (!ShapeUtil::IsArray(operand_shape) || + !primitive_util::IsArrayType(new_element_type)) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions // are valid. For now we just reject them, though. @@ -542,7 +421,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { - if (ShapeUtil::IsTuple(operand_shape)) { + if (!ShapeUtil::IsArray(operand_shape)) { return InvalidArgument( "Pad operation does not support tuple-shape operands."); } @@ -681,8 +560,8 @@ Status ValidateDotDimensionNumbers( /* static */ StatusOr ShapeInference::InferDotOpShape( const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); auto fail = [lhs, rhs](const string& addendum) -> Status { string message = tensorflow::strings::Printf( @@ -768,8 +647,9 @@ Status ValidateDotDimensionNumbers( } /* static */ StatusOr -ShapeInference::InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs) { +ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, + const Shape& lhs, + const Shape& rhs) { TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); // The shapes have to be compatible. That is, if some dimension d has a @@ -787,7 +667,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", - BinaryOperation_Name(operation).c_str(), + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -797,8 +677,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions) { if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring @@ -899,18 +778,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Binary op %s with different element types: %s and %s.", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -943,10 +819,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. - TF_ASSIGN_OR_RETURN( - Shape indim_broadcast_shape, - InferInDimBroadcastShape(operation, smaller_shape, larger_shape, - broadcast_dimensions)); + TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, + InferInDimBroadcastShape(smaller_shape, larger_shape, + broadcast_dimensions)); return InferDegenerateDimensionBroadcastShape( operation, indim_broadcast_shape, larger_shape); @@ -955,51 +830,44 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(), - rhs->shape(), /*broadcast_dimensions=*/{}); + return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(), + /*broadcast_dimensions=*/{}); } /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs, - broadcast_dimensions); -} - -/* static */ StatusOr ShapeInference::InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { VLOG(2) << tensorflow::strings::Printf( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), + HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str(), Join(broadcast_dimensions, ", ").c_str()); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - lhs, tensorflow::strings::StrCat("lhs of binary operation ", - BinaryOperation_Name(operation)))); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - rhs, tensorflow::strings::StrCat("rhs of binary operation ", - BinaryOperation_Name(operation)))); - switch (operation) { - case BINOP_MAX: - case BINOP_MIN: - case BINOP_SUB: - case BINOP_ADD: - case BINOP_ATAN2: - case BINOP_POW: - case BINOP_DIV: - case BINOP_REM: - case BINOP_MUL: - case BINOP_SHIFT_LEFT: - case BINOP_SHIFT_RIGHT_ARITHMETIC: - case BINOP_SHIFT_RIGHT_LOGICAL: - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + TF_RETURN_IF_ERROR( + ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ", + HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR( + ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ", + HloOpcodeString(opcode)))); + switch (opcode) { + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kSubtract: + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kPower: + case HloOpcode::kDivide: + case HloOpcode::kRemainder: + case HloOpcode::kMultiply: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_COMPLEX: { + case HloOpcode::kComplex: { if (!ShapeUtil::ElementIsFloating(lhs)) { return InvalidArgument( "Expected element type in shape to be floating for complex compose " @@ -1007,7 +875,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(lhs.element_type()).c_str()); } TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); @@ -1015,8 +883,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return Unimplemented("Complex component type is not implemented."); } } - case BINOP_AND: - case BINOP_OR: + case HloOpcode::kAnd: + case HloOpcode::kOr: if (lhs.element_type() != PRED && !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( @@ -1024,24 +892,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "got %s.", PrimitiveType_Name(lhs.element_type()).c_str()); } - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_EQ: - case BINOP_GE: - case BINOP_GT: - case BINOP_LE: - case BINOP_LT: - case BINOP_NE: { + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: { TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); return ShapeUtil::ChangeElementType(shape, PRED); } default: return Unimplemented( "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", - BinaryOperation_Name(operation).c_str(), - lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str()); + HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), + rhs.ShortDebugString().c_str()); } } @@ -1053,23 +921,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { - return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs); -} - -/* static */ StatusOr ShapeInference::InferTernaryOpShape( - TernaryOperation operation, const Shape& lhs, const Shape& rhs, - const Shape& ehs) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs)); - switch (operation) { - case TRIOP_CLAMP: + switch (opcode) { + case HloOpcode::kClamp: return InferClampShape(lhs, rhs, ehs); - case TRIOP_SELECT: + case HloOpcode::kSelect: return InferSelectShape(lhs, rhs, ehs); default: return InvalidArgument("Unknown operation %s.", - TernaryOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -1077,6 +939,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operands) { std::vector operand_shapes; + operand_shapes.reserve(operands.size()); for (const HloInstruction* operand : operands) { operand_shapes.push_back(&operand->shape()); } @@ -1086,19 +949,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes) { - return InferVariadicOpShape(OpcodeToVariadicOperation(opcode), - operand_shapes); -} - -/* static */ StatusOr ShapeInference::InferVariadicOpShape( - VariadicOperation operation, - tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } - switch (operation) { - case VAROP_TUPLE: { + switch (opcode) { + case HloOpcode::kTuple: { Shape result = ShapeUtil::MakeTupleShape({}); + result.mutable_tuple_shapes()->Reserve(operand_shapes.size()); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); } @@ -1106,7 +963,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } default: return InvalidArgument("Unknown operation %s.", - VariadicOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -1121,15 +978,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // All arguments must have the same shape. const Shape* arg_shape = arg_shapes[0]; for (size_t i = 1; i < arg_shapes.size(); ++i) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map")); + TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map")); if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { continue; } - if (!ShapeUtil::IsTuple(*arg_shapes[i]) && - !ShapeUtil::IsTuple(*arg_shape) && - ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], + if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { if (ShapeUtil::IsScalar(*arg_shapes[i])) { continue; @@ -1212,11 +1066,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& offset_shape, int64 feature_index) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - offset_shape, "offset input of batch norm training")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - scale_shape, "scale input of batch norm training")); + ExpectArray(operand_shape, "operand of batch norm training")); + TF_RETURN_IF_ERROR( + ExpectArray(offset_shape, "offset input of batch norm training")); + TF_RETURN_IF_ERROR( + ExpectArray(scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == Status::OK()); @@ -1318,11 +1172,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& offset_shape, const Shape& mean_shape, const Shape& variance_shape, int64 feature_index) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - offset_shape, "offset input of batch norm inference")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - scale_shape, "scale input of batch norm inference")); + ExpectArray(operand_shape, "operand of batch norm inference")); + TF_RETURN_IF_ERROR( + ExpectArray(offset_shape, "offset input of batch norm inference")); + TF_RETURN_IF_ERROR( + ExpectArray(scale_shape, "scale input of batch norm inference")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == Status::OK()); @@ -1465,16 +1319,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& operand_shape, const Shape& scale_shape, const Shape& mean_shape, const Shape& var_shape, const Shape& output_grad_shape, int64 feature_index) { + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad")); + ExpectArray(scale_shape, "scale input of batch norm grad")); + TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad")); + TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - output_grad_shape, "output_grad input of batch norm grad")); + ExpectArray(output_grad_shape, "output_grad input of batch norm grad")); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); @@ -1623,8 +1474,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dnums) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution")); + TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); + TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( @@ -1859,7 +1710,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* operand_shape : operand_shapes) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum")); + ExpectArray(*operand_shape, "operand of cross replica sum")); } if (operand_shapes.size() == 1) { return *operand_shapes[0]; @@ -1901,8 +1752,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window")); + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape, operand_shape.element_type())); return InferWindowOutputShape(operand_shape, window, @@ -1915,7 +1765,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Window& window, const Shape& source_shape, const Shape& init_value_shape, const ProgramShape& scatter_shape) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter")); + ExpectArray(operand_shape, "operand of select-and-scatter")); // Check if the select function has a proper shape of (T,T) -> PRED. if (select_shape.parameters_size() != 2) { @@ -1980,7 +1830,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( Join(starts, ",").c_str(), Join(limits, ",").c_str(), Join(strides, ",").c_str()); }; - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice")); + TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s starts={%s} limits={%s}", ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), @@ -2039,10 +1889,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferDynamicSliceShape( const Shape& operand_shape, const Shape& start_indices_shape, tensorflow::gtl::ArraySlice slice_sizes) { + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape, - "start indices of dynamic slice")); + ExpectArray(start_indices_shape, "start indices of dynamic slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", @@ -2100,11 +1949,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& operand_shape, const Shape& update_shape, const Shape& start_indices_shape) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice")); + ExpectArray(operand_shape, "operand of dynamic update slice")); TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - start_indices_shape, "start indices of dynamic update slice")); + ExpectArray(update_shape, "update of dynamic update slice")); + TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, + "start indices of dynamic update slice")); VLOG(2) << tensorflow::strings::Printf( "updating slice of shape %s at dynamic start_indices %s with update " @@ -2172,8 +2021,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /*static */ StatusOr ShapeInference::InferReverseShape( const Shape& operand_shape, tensorflow::gtl::ArraySlice dimensions) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand_shape, "operand of reverse")); + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse")); if (!AllUnique(dimensions)) { return InvalidArgument("a dimension number is duplicated in reverse"); } @@ -2303,7 +2151,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); for (int64 size : broadcast_sizes) { if (size < 0) { return InvalidArgument("Broadcast with negative dimension size %lld.", @@ -2322,7 +2170,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferReshapeShape( const Shape& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape")); Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), new_sizes); @@ -2354,7 +2202,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTransposeShape( const Shape& operand, tensorflow::gtl::ArraySlice dimensions) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose")); std::vector indices(ShapeUtil::Rank(operand)); std::iota(indices.begin(), indices.end(), 0); @@ -2375,9 +2223,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // "degenerate" cases, as with binary elementwise ops. /* static */ StatusOr ShapeInference::InferClampShape( const Shape& min, const Shape& operand, const Shape& max) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); + TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min")); + TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand")); + TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max")); if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("Clamp with different operand types: %s, %s, %s.", @@ -2576,9 +2424,9 @@ static Status ValidateGatherDimensionNumbers( const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds) { TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( - gather_indices_shape, "gather indices operand of gather op")); + ExpectArray(input_shape, "input tensor operand gather op")); + TF_RETURN_IF_ERROR( + ExpectArray(gather_indices_shape, "gather indices operand of gather op")); if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 9da2c99b4177f08ece8daabaf2922ddd7e947a1b..eef6e62fc8d933452ebc3f9a5b8bc49828455be5 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -46,8 +46,6 @@ class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. - static StatusOr InferUnaryOpShape(UnaryOperation operation, - const Shape& arg); static StatusOr InferUnaryOpShape(HloOpcode opcode, const Shape& shape); static StatusOr InferUnaryOpShape(HloOpcode opcode, @@ -55,9 +53,6 @@ class ShapeInference { // Infers the shape produced by applying the given binary operation to the // given input shapes. - static StatusOr InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); static StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); @@ -67,9 +62,6 @@ class ShapeInference { // Infers the shape produced by applying the given ternary operation to the // given input shapes. - static StatusOr InferTernaryOpShape(TernaryOperation operation, - const Shape& lhs, const Shape& rhs, - const Shape& ehs); static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs); @@ -80,9 +72,6 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. - static StatusOr InferVariadicOpShape( - VariadicOperation operation, - tensorflow::gtl::ArraySlice operand_shapes); static StatusOr InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes); @@ -227,6 +216,13 @@ class ShapeInference { static StatusOr InferConcatOpShape( tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + // Infers the shape produced by a kGenerateToken operation. Trivially this + // shape is always a TOKEN shape. However, ShapeInference serves two purposes: + // inferring shapes and checking operand shapes. This method verifies that the + // operand shapes are all TOKENs. + static StatusOr InferGenerateTokenShape( + tensorflow::gtl::ArraySlice arg_shapes); + // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. @@ -279,7 +275,7 @@ class ShapeInference { // the LHS and a single element in the RHS to produce a single output element, // even in the presence of broadcasting of one of the operands over the other. static StatusOr InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); // Helper for inferring the shape of Clamp ops. @@ -295,7 +291,7 @@ class ShapeInference { // dimension broadcasting (a dimension of size 1 in one operand is broadcast // up to match the size of the dimension in the other operand). static StatusOr InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs); + HloOpcode operation, const Shape& lhs, const Shape& rhs); // Helper for inferring shapes of binary operations using "InDim" // broadcasting. This is the broadcasting used in the *InDim binary operations @@ -303,8 +299,7 @@ class ShapeInference { // lower-rank shape than larger_shape. Returns the shape that the // smaller_shape is broadcast to. static StatusOr InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 0e61994a786b53a295ef9c9c2287b28fbf754d9b..bafe14d6f45f851924c37908d4c93bbff2dac459 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -101,8 +101,8 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = ShapeInference::InferUnaryOpShape( - UnaryOperation::UNOP_NEGATE, matrix_shape); + auto inferred_status = + ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie())); } @@ -110,14 +110,14 @@ TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple); + HloOpcode::kSelect, pred_, tuple, tuple); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } @@ -125,34 +125,34 @@ TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { auto predarray = ShapeUtil::MakeShape(PRED, {64, 48}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectBadShapes) { auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("pred operand must have PRED")); auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), - matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, + matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("with non-scalar predicate with dimensionality")); // Tuples have a TUPLE element type and cannot be the pred of a select. auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}), + HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}), ShapeUtil::MakeTupleShape({f32_, f32_}), ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); @@ -162,102 +162,98 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { TEST_F(ShapeInferenceTest, ClampAllMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, - matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampAllScalar) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_); + auto inferred_status = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_); + HloOpcode::kClamp, matrix_64_48_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_); + HloOpcode::kClamp, f32_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_); + HloOpcode::kClamp, f32_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampBadShapes) { // Type mismatch - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_) - .ok()); - // Dimension mismatch ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_64_, vector_32_, vector_32_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_64_, vector_32_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_32_, vector_64_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_) .ok()); - // Dimension mismatch, where one operand is a scalar + // Dimension mismatch ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_) + HloOpcode::kClamp, vector_64_, vector_32_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_64_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_32_, vector_64_) + .ok()); + // Dimension mismatch, where one operand is a scalar + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, vector_32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, f32_, vector_32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, + vector_64_, vector_32_) .ok()); } TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, const tensorflow::gtl::ArraySlice& bcast) { - return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX, - lhs, rhs, bcast); + return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, + bcast); }; // Inputs must be FP. ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok()); @@ -292,8 +288,8 @@ TEST_F(ShapeInferenceTest, Complex) { } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { - StatusOr result = ShapeInference::InferVariadicOpShape( - VariadicOperation::VAROP_TUPLE, {&s32_, &f32_}); + StatusOr result = + ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_}); ASSERT_IS_OK(result.status()); ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), ShapeUtil::MakeTupleShape({s32_, f32_}))); @@ -804,8 +800,8 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) { TEST_F(ShapeInferenceTest, InferPowShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {}); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kPower, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie())); } @@ -813,7 +809,7 @@ TEST_F(ShapeInferenceTest, InferPowShape) { TEST_F(ShapeInferenceTest, InferCompareShapeEq) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -822,7 +818,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeEq) { TEST_F(ShapeInferenceTest, InferCompareShapeGe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -831,7 +827,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGe) { TEST_F(ShapeInferenceTest, InferCompareShapeGt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -840,7 +836,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGt) { TEST_F(ShapeInferenceTest, InferCompareShapeLe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -849,7 +845,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLe) { TEST_F(ShapeInferenceTest, InferCompareShapeLt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -858,7 +854,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLt) { TEST_F(ShapeInferenceTest, InferCompareShapeNe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -1111,22 +1107,22 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); const Shape vec16 = ShapeUtil::MakeShape(F32, {16}); - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {1}); + auto inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {0}); + auto inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0}); ASSERT_FALSE(inferred_status_mismatch.ok()); - inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {0}); + inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {1}); + inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1}); ASSERT_FALSE(inferred_status_mismatch.ok()); } @@ -1138,17 +1134,17 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2}); + HloOpcode::kAdd, cube, matrix8_4, {1, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2}); + HloOpcode::kAdd, cube, matrix16_4, {0, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1}); + HloOpcode::kAdd, cube, matrix16_8, {0, 1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); } @@ -1162,43 +1158,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); // "magical" broadcast rejected - auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {}); + auto inferred_status_error1 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Automatic")); // broadcast_dimension out of bounds for tensor's rank - auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {3}); + auto inferred_status_error2 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension - auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {0}); + auto inferred_status_error3 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); + HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), HasSubstr("broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); + HloOpcode::kAdd, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); + HloOpcode::kAdd, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), HasSubstr("dimension 0 mismatch")); @@ -1207,13 +1203,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { // in a proper (strictly increasing) order, even if the lower-rank array // matches the higher-rank array in many different ways. auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), HasSubstr("dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), HasSubstr("dimensions order is wrong")); @@ -1315,7 +1311,7 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) { ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT( inferred_status_error4.status().error_message(), - HasSubstr("Expected non-tuple argument for operand of concatenation")); + HasSubstr("Expected array argument for operand of concatenation")); const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32}); auto inferred_status_error5 = ShapeInference::InferConcatOpShape( @@ -1391,7 +1387,7 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) { ShapeInference::InferReverseShape(tuple_shape, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), - HasSubstr("Expected non-tuple argument")); + HasSubstr("Expected array argument")); } TEST_F(ShapeInferenceTest, Call) { @@ -1690,7 +1686,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Expected non-tuple argument for input")) + HasSubstr("Expected array argument for input")) << statusor.status(); } @@ -1704,7 +1700,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Expected non-tuple argument for gather indices")) + HasSubstr("Expected array argument for gather indices")) << statusor.status(); } diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index c4d01562c4e32225ebb984d8fcd93ec3fa86e403..4c5038a009ba5da4172129980014913f3f4418f4 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -22,8 +22,12 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/notification.h" + +using ::tensorflow::strings::StrCat; namespace xla { /* static */ tensorflow::mutex @@ -36,8 +40,73 @@ TransferManager::GetPlatformTransferManagers() { return r; } +StatusOr> TransferManager::TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer) { + StatusOr> ret; + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + tensorflow::Notification n; + TransferLiteralFromDevice(substream, device_buffer, + [&](StatusOr> arg) { + ret = std::move(arg); + n.Notify(); + }); + n.WaitForNotification(); + return ret; +} + +Status TransferManager::TransferLiteralToDevice( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + TF_RETURN_IF_ERROR( + TransferLiteralToDeviceAsync(substream, literal, device_buffer)); + return substream->BlockHostUntilDone(); +} + +StatusOr> TransferManager::TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + StatusOr> ret; + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + + tensorflow::Notification n; + TransferArrayFromDevice(substream, shape, source, + [&](StatusOr> arg) { + ret = std::move(arg); + n.Notify(); + }); + n.WaitForNotification(); + return ret; +} + Status TransferManager::TransferArrayToDevice( - se::StreamExecutor* executor, const LiteralSlice& literal, + se::Stream* stream, const LiteralSlice& literal, + const se::DeviceMemoryBase& dest) { + // Implement the synchronous version by waiting on the asynchronous version. + // Use a substream so that if we are called from a HostCallback we don't + // deadlock. + se::Stream* substream = stream->GetOrCreateSubStream(); + auto cleanup = tensorflow::gtl::MakeCleanup( + [&]() { stream->ReturnSubStream(substream); }); + TF_RETURN_IF_ERROR(TransferArrayToDeviceAsync(substream, literal, dest)); + return substream->BlockHostUntilDone(); +} + +Status TransferManager::TransferArrayToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) @@ -51,28 +120,32 @@ Status TransferManager::TransferArrayToDevice( dest.size(), GetByteSizeRequirement(on_device_shape)); } ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, - executor->platform(), executor->device_ordinal()); + stream->parent()->platform(), + stream->parent()->device_ordinal()); shaped_buffer.set_buffer(dest, /*index=*/{}); - return TransferLiteralToDevice(executor, literal, shaped_buffer); + return TransferLiteralToDevice(stream, literal, shaped_buffer); } -StatusOr> TransferManager::TransferArrayFromDevice( - se::StreamExecutor* executor, const Shape& shape, - const se::DeviceMemoryBase& source) { - TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) - << "Shape " << ShapeUtil::HumanString(shape) - << " has a differently shaped representation on-device: " - << ShapeUtil::HumanString(HostShapeToDeviceShape(shape)); +void TransferManager::TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source, + std::function>)> done) { + if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) { + auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), + " has a differently shaped representation on-device: ", + ShapeUtil::HumanString(HostShapeToDeviceShape(shape))); + return done(FailedPrecondition("%s", error.c_str())); + } if (source.size() < GetByteSizeRequirement(shape)) { - return FailedPrecondition( - "Allocation on device not large enough for array: " - "%lld < %lld", - source.size(), GetByteSizeRequirement(shape)); + return done( + FailedPrecondition("Allocation on device not large enough for array: " + "%lld < %lld", + source.size(), GetByteSizeRequirement(shape))); } ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, - executor->platform(), executor->device_ordinal()); + stream->parent()->platform(), + stream->parent()->device_ordinal()); shaped_buffer.set_buffer(source, /*index=*/{}); - return TransferLiteralFromDevice(executor, shaped_buffer); + return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done)); } /* static */ void TransferManager::RegisterTransferManager( @@ -108,10 +181,14 @@ StatusOr> TransferManager::TransferArrayFromDevice( } Status TransferManager::WriteTupleIndexTables( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { - VLOG(2) << "Writing tuple index tables for " << device_buffer; + se::Stream* stream, const ShapedBuffer& device_buffer) { + TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer)); + return stream->BlockHostUntilDone(); +} - TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); +Status TransferManager::WriteTupleIndexTablesAsync( + se::Stream* stream, const ShapedBuffer& device_buffer) { + VLOG(2) << "Writing tuple index tables for " << device_buffer; return ShapeUtil::ForEachSubshapeWithStatus( device_buffer.on_device_shape(), @@ -129,7 +206,7 @@ Status TransferManager::WriteTupleIndexTables( elements.push_back(device_buffer.buffer(element_index)); element_index.pop_back(); } - return WriteSingleTupleIndexTable(executor, elements, device_subshape, + return WriteSingleTupleIndexTable(stream, elements, device_subshape, &device_memory); } @@ -138,26 +215,20 @@ Status TransferManager::WriteTupleIndexTables( } Status TransferManager::TransferBufferFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - int64 size, void* destination) { + se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, + void* destination) { if (source.size() < size) { return FailedPrecondition( "Source allocation on device not large enough for data tranfer: " "%lld < %lld", source.size(), size); } - auto copy_status = executor->SynchronousMemcpyD2H(source, size, destination); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer from device to buffer"); - } + stream->ThenMemcpy(destination, source, size); return Status::OK(); } Status TransferManager::TransferBufferToDevice( - se::StreamExecutor* executor, int64 size, const void* source, + se::Stream* stream, int64 size, const void* source, se::DeviceMemoryBase* destination) { if (destination->size() < size) { return FailedPrecondition( @@ -165,13 +236,7 @@ Status TransferManager::TransferBufferToDevice( "%lld < %lld", destination->size(), size); } - auto copy_status = executor->SynchronousMemcpyH2D(source, size, destination); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer of buffer to device"); - } + stream->ThenMemcpy(destination, source, size); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 43a8092b06fba0e2495bce0ee1a309c85a908273..e384359642a8fe09e0b8516e342a56259912922a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -52,30 +52,65 @@ class TransferManager { return host_shape; } - // Returns a literal containing the data held in the given ShapedBuffer. - // using the provided executor. The optional literal_shape will be the shape - // for the literal. The shape of the ShapedBuffer and - // DeviceShape(literal_shape) must be compatible, but need not have the same - // layout. + // Returns a literal containing the data held in the given ShapedBuffer + // using the provided executor. This operation is performed synchronously + // without waiting for any other operation on a stream to complete. + // + // This function should be avoided in favor of the asynchronous version below. virtual StatusOr> TransferLiteralFromDevice( - se::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0; + se::Stream* stream, const ShapedBuffer& device_buffer); + + // Begins transferring a literal containing the data held in the given + // ShapedBuffer using the provided executor. + // + // This operation is performed asynchronously on the given stream. It returns + // once the transfer is enqueued. 'done' is invoked with the result when + // complete. + // + // device_buffer is copied by reference and must live at least until done() is + // invoked. + virtual void TransferLiteralFromDevice( + se::Stream* stream, const ShapedBuffer& device_buffer, + std::function>)> done) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, - // but need not have the same layout - virtual Status TransferLiteralToDevice(se::StreamExecutor* executor, + // but need not have the same layout. + // + // This operation is performed synchronously without waiting for any other + // operation on a stream to complete. This function should be avoided in favor + // of the asynchronous version below. + virtual Status TransferLiteralToDevice(se::Stream* stream, const LiteralSlice& literal, - const ShapedBuffer& device_buffer) = 0; + const ShapedBuffer& device_buffer); + + // Transfers the given literal into the previously allocated device memory + // represented by the given ShapedBuffer using the given executor. The shape + // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, + // but need not have the same layout. + // + // This operation is performed asynchronously on the given stream. It returns + // once the transfer is enqueued. + virtual Status TransferLiteralToDeviceAsync( + se::Stream* stream, const LiteralSlice& literal, + const ShapedBuffer& device_buffer) = 0; // Convenience methods for transferring an array to or from the device at a // known address. This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. - Status TransferArrayToDevice(se::StreamExecutor* executor, - const LiteralSlice& literal, + Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal, const se::DeviceMemoryBase& dest); + void TransferArrayFromDevice( + se::Stream* stream, const Shape& shape, + const se::DeviceMemoryBase& source, + std::function>)> done); + + Status TransferArrayToDeviceAsync(se::Stream* stream, + const LiteralSlice& literal, + const se::DeviceMemoryBase& dest); StatusOr> TransferArrayFromDevice( - se::StreamExecutor* executor, const Shape& shape, + se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, @@ -96,8 +131,10 @@ class TransferManager { // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the // ShapedBuffer is array-shaped this method does nothing. - Status WriteTupleIndexTables(se::StreamExecutor* executor, + Status WriteTupleIndexTables(se::Stream* stream, const ShapedBuffer& device_buffer); + Status WriteTupleIndexTablesAsync(se::Stream* stream, + const ShapedBuffer& device_buffer); // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory @@ -144,7 +181,7 @@ class TransferManager { // 'destination' buffer. // // size is the size to transfer to destination in bytes. - virtual Status TransferBufferFromDevice(se::StreamExecutor* executor, + virtual Status TransferBufferFromDevice(se::Stream* stream, const se::DeviceMemoryBase& source, int64 size, void* destination); @@ -152,15 +189,15 @@ class TransferManager { // destination of the device. // // size is the size to transfer from source in bytes. - virtual Status TransferBufferToDevice(se::StreamExecutor* executor, - int64 size, const void* source, + virtual Status TransferBufferToDevice(se::Stream* stream, int64 size, + const void* source, se::DeviceMemoryBase* destination); // Writes the given device-memory pointers in 'elements' to the given region // to construct a tuple index table in the platform-specific tuple // representation. virtual Status WriteSingleTupleIndexTable( - se::StreamExecutor* executor, + se::Stream* stream, tensorflow::gtl::ArraySlice elements, const Shape& shape, se::DeviceMemoryBase* region) = 0; diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index ba16dc640e2d2974eab4fc8b134a6e33c03e3b85..49e1f873192f800056a2272f7d4f698898b0f8a1 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -178,7 +178,6 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); - convolution.SetupDerivedInstruction(new_conv.get()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 3139801ea3130324f48d728dc6f739f709e55911..cccb8f2fbb0266bbf1f40b09170938a1e5d3e78d 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -176,7 +176,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( - {add, sub, mul}, "", entry_computation); + {add, sub, mul}, "entry", entry_computation); EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); // The arguments to the call should be const1, const2, and const3. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index bb634e6573ffceeaa66e0ac9141fe7e3a39ed602..eb6d1ada6b553f998fe06917dfdf0b5092cd79cd 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -723,15 +723,16 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return false; } if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -789,8 +790,12 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return param_uses.size() == 1 && param_uses[0].first == callee_root && callee_root->IsElementwiseOnOperand(param_uses[0].second); } - // Check if 'user' is element-wise. - return user->IsElementwise(); + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation. + return user->IsElementwiseOnOperand(user->operand_index(operand)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index f558316b05b168a6f100e8ef69adfd9dbc023102..5734f284071944bc22011405898cf86f33dc48d7 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1148,5 +1148,30 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { call, {})); } +TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) { + Shape full_shape = ShapeUtil::MakeShape(F32, {16, 32}); + Shape broadcast_shape = ShapeUtil::MakeShape(F32, {16}); + + auto builder = HloComputation::Builder(TestName() + "_fusion"); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, full_shape, "full")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, broadcast_shape, "small")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(full_shape, param1, {0})); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + full_shape, HloOpcode::kAdd, param0, broadcast)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, broadcast}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc index d668855084a884518b338cdf396a9330b9f43a2b..77bdcc9de0d830991208a1db271d009bccaf550e 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -30,10 +30,17 @@ limitations under the License. namespace xla { +TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) : + exclude_entry_computation_(exclude_entry_computation) {} + StatusOr TupleSimplifier::Run(HloModule* module) { // Initially add all GTE and Tuple instructions to the worklist. std::queue worklist; for (auto* computation : module->computations()) { + if (exclude_entry_computation_ && + computation == module->entry_computation()) { + continue; + } for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kTuple || instruction->opcode() == HloOpcode::kGetTupleElement) { @@ -69,7 +76,6 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // Tuple // HloInstruction* top_tuple = nullptr; - HloInstruction* first_gte = nullptr; bool can_simplify = true; for (int64 operand_number = 0; operand_number < instruction->operand_count(); ++operand_number) { @@ -79,17 +85,10 @@ StatusOr TupleSimplifier::Run(HloModule* module) { can_simplify = false; break; } - if (first_gte == nullptr) { - first_gte = operand; - } else if (!first_gte->has_compatible_sharding(operand)) { - can_simplify = false; - break; - } if (top_tuple == nullptr) { top_tuple = operand->mutable_operand(0); if (!ShapeUtil::Compatible(top_tuple->shape(), - instruction->shape()) || - !instruction->has_compatible_sharding(top_tuple)) { + instruction->shape())) { can_simplify = false; break; } @@ -118,14 +117,12 @@ StatusOr TupleSimplifier::Run(HloModule* module) { HloInstruction* element_source = instruction->mutable_operand(0)->mutable_operand( instruction->tuple_index()); - if (instruction->has_compatible_sharding(element_source)) { - changed = true; - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); - for (HloInstruction* user : element_source->users()) { - if (user->opcode() == HloOpcode::kTuple || - user->opcode() == HloOpcode::kGetTupleElement) { - worklist.push(user); - } + changed = true; + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); + for (HloInstruction* user : element_source->users()) { + if (user->opcode() == HloOpcode::kTuple || + user->opcode() == HloOpcode::kGetTupleElement) { + worklist.push(user); } } } diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index e5e9b10b5bf3f452d1bfec476b8d5c7d74c4f4e8..750950188312c5077d487f2feef0606f07839432 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -27,13 +27,20 @@ namespace xla { // the module. class TupleSimplifier : public HloPassInterface { public: - TupleSimplifier() {} + TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} + explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} tensorflow::StringPiece name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + private: + // When set, this pipeline stage will perform optimization of all computations + // apart from the module's entry computation. This is used by Graphcore's + // backend. + bool exclude_entry_computation_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index ca9ae91281fce5ee061d066fc3e538dbbc09f6b3..d3635eae81ec7017f9bf6a69250d10716309c9ec 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -42,6 +42,12 @@ class TupleSimplifierTest : public HloTestBase { TF_ASSERT_OK(changed_status.status()); EXPECT_EQ(change_expected, changed_status.ValueOrDie()); } + void Run(HloModule* module, bool change_expected, bool exclude_entry) { + TupleSimplifier simplifier(exclude_entry); + auto changed_status = simplifier.Run(module); + TF_ASSERT_OK(changed_status.status()); + EXPECT_EQ(change_expected, changed_status.ValueOrDie()); + } const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( @@ -211,5 +217,76 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); } +TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { + // Verify that the root computation can be excluded + auto module = CreateNewModule(); + + HloInstruction* p0; + HloInstruction* p1; + HloComputation* c0; + HloComputation* c1; + HloComputation* entry; + + { + HloComputation::Builder builder(TestName() + "_1"); + p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c0 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_2"); + p1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c1 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_Entry"); + HloInstruction* tuple_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* call0 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0)); + HloInstruction* call1 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1)); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0)); + HloInstruction* gte3 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3})); + + entry = module->AddEntryComputation(builder.Build()); + } + + Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + + EXPECT_THAT(c0->root_instruction(), p0); + EXPECT_THAT(c1->root_instruction(), p1); + EXPECT_THAT(entry->instruction_count(), 9); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h deleted file mode 100644 index 5732a56caffa31dde52dff5c2775f9fde0cacfbd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// A data structure encapsulating a ComputationHandle and version value of that -// computation. This object is used to unambiguously refer to a particular -// computation in the service. -struct VersionedComputationHandle { - // A version value unambiguously specifying the state of the computation at a - // particular point in time as it is being built. This value is the - // ComputationDataHandle of the current root instruction. - using Version = int64; - - ComputationHandle handle; - Version version; - - string ToString() const; - bool operator==(const VersionedComputationHandle& other) const { - return (handle.handle() == other.handle.handle()) && - (version == other.version); - } - bool operator<(const VersionedComputationHandle& other) const { - return ((handle.handle() < other.handle.handle()) || - ((handle.handle() == other.handle.handle()) && - (version < other.version))); - } -}; - -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc index aa40b5cb264803097f52966d6f61f1f41b6b3017..44b0ec5cd4c1d406467007fcc530e919d602c438 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -32,11 +32,11 @@ StatusOr ZeroSizedHloElimination::Run(HloModule* module) { for (HloComputation* comp : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { if (instruction->HasSideEffect() || - ShapeUtil::IsTuple(instruction->shape())) { + !ShapeUtil::IsArray(instruction->shape())) { continue; } if (comp->IsRemovable(instruction) && - ShapeUtil::HasZeroElements(instruction->shape())) { + ShapeUtil::IsZeroElementArray(instruction->shape())) { TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConstant( Literal::CreateFromShape(instruction->shape())))); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 5b14953ebb243da7b9be6eafd46160db8bc62707..4aacc87b78e2c271829cdf397cd69bfb490125b8 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -47,6 +47,9 @@ struct ShapeTreeNode { // Children of this node, as indices into the container's nodes_ array. std::vector children; + // Tells whether this is a leaf node. + bool is_leaf = true; + explicit ShapeTreeNode(ShapeIndex index) : ShapeTreeNode(std::move(index), T()) {} ShapeTreeNode(ShapeIndex index, T data) @@ -102,8 +105,8 @@ class ShapeTree { // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). - const T& element(const ShapeIndex& index) const; - T* mutable_element(const ShapeIndex& index); + const T& element(ShapeIndexView index) const; + T* mutable_element(ShapeIndexView index); // Return the shape represented with this ShapeTree. const Shape& shape() const { return *shape_; } @@ -122,9 +125,7 @@ class ShapeTree { // Returns true if the node at the given index is a leaf node (an array // shape). - bool IsLeaf(const ShapeIndex& index) const { - return Lookup(index)->children.empty(); - } + bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; } ShapeTree(const ShapeTree&) = default; ShapeTree& operator=(const ShapeTree&) = default; @@ -210,12 +211,12 @@ class ShapeTree { // Returns an iterator pointing to the given ShapeIndex. // REQUIRES: index must exist in the ShapeTree. - iterator find(const ShapeIndex& index) { + iterator find(ShapeIndexView index) { Node* element = Lookup(index); return iterator(&nodes_, typename std::vector::iterator(element), /*iterate_leaves_only=*/false); } - const_iterator find(const ShapeIndex& index) const { + const_iterator find(ShapeIndexView index) const { Node* element = Lookup(index); return iterator(&nodes_, typename std::vector::const_iterator(element), @@ -284,8 +285,8 @@ class ShapeTree { static Status ForEachMutableHelper(const Fn& func, std::vector* nodes); // Return the tree node at the given index. - Node* Lookup(const ShapeIndex& index); - const Node* Lookup(const ShapeIndex& index) const; + Node* Lookup(ShapeIndexView index); + const Node* Lookup(ShapeIndexView index) const; // The nodes in this shape tree. std::vector nodes_; @@ -311,16 +312,14 @@ class ShapeTreeIterator : nodes_(nodes), node_(std::move(node)), iterate_leaves_only_(iterate_leaves_only) { - while (iterate_leaves_only && node_ != nodes_->end() && - !node_->children.empty()) { + while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) { ++node_; } } ShapeTreeIterator& operator++() { ++node_; - while (iterate_leaves_only_ && node_ != nodes_->end() && - !node_->children.empty()) { + while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) { ++node_; } return *this; @@ -333,8 +332,7 @@ class ShapeTreeIterator ShapeTreeIterator& operator--() { --node_; - while (iterate_leaves_only_ && node_ > nodes_->begin() && - !node_->children.empty()) { + while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) { --node_; } return *this; @@ -358,7 +356,7 @@ class ShapeTreeIterator ContainerType* nodes_; IteratorType node_; // True if we should not include interior nodes in our walk. - bool iterate_leaves_only_; + const bool iterate_leaves_only_; }; template @@ -379,6 +377,7 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); node->children.reserve(size); + node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); for (int i = 0; i < size; ++i) { @@ -395,6 +394,7 @@ void ShapeTree::InitChildren(const Shape& shape, Node* node) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); node->children.reserve(size); + node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); for (int i = 0; i < size; ++i) { @@ -463,17 +463,17 @@ ShapeTree::ShapeTree(const std::shared_ptr& shape, } template -const T& ShapeTree::element(const ShapeIndex& index) const { +const T& ShapeTree::element(ShapeIndexView index) const { return Lookup(index)->data.second; } template -T* ShapeTree::mutable_element(const ShapeIndex& index) { +T* ShapeTree::mutable_element(ShapeIndexView index) { return &Lookup(index)->data.second; } template -internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { +internal::ShapeTreeNode* ShapeTree::Lookup(ShapeIndexView index) { Node* node = &nodes_[0]; for (const int64 i : index) { CHECK_GE(i, 0); @@ -485,7 +485,7 @@ internal::ShapeTreeNode* ShapeTree::Lookup(const ShapeIndex& index) { template const internal::ShapeTreeNode* ShapeTree::Lookup( - const ShapeIndex& index) const { + ShapeIndexView index) const { return const_cast(this)->Lookup(index); } diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index dc5facf1581c07fbb74dfcee95025692938632bd..51de82e95746281ed6e587b545dc933b48ce1ad4 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -116,6 +116,11 @@ TEST_F(ShapeTreeTest, InitValueConstructor) { TestInitValueConstructor(nested_tuple_shape_, 10); } +TEST_F(ShapeTreeTest, EmptyTupleMustHaveNoLeaves) { + ShapeTree shape_tree{ShapeUtil::MakeTupleShape({})}; + EXPECT_EQ(0, shape_tree.leaf_count()); +} + TEST_F(ShapeTreeTest, ArrayShape) { ShapeTree shape_tree{array_shape_}; *shape_tree.mutable_element({}) = 42; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ce4d0079ee5eb28444509c712ec1a34037dc244a..98c3095499f23a816722430846da7c7cbe2ece67 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -263,6 +264,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( tensorflow::gtl::ArraySlice shapes) { Shape result; result.set_element_type(TUPLE); + result.mutable_tuple_shapes()->Reserve(shapes.size()); for (const auto& shape : shapes) { AppendShapeToTuple(shape, &result); } @@ -363,7 +365,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::IsNil(const Shape& shape) { - return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape); + return IsEmptyTuple(shape); } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { @@ -379,6 +381,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return shape.tuple_shapes(index); } +/* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) { + int64 n = 0; + ForEachSubshape(shape, [&](const Shape& literal_subshape, + const ShapeIndex& index) { ++n; }); + return n; +} + /* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start, int64 limit) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple)); @@ -413,15 +422,26 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( std::multiplies()); } -/* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) { - return ElementsIn(shape) == 0; +/* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) { + CHECK(IsArray(shape) || IsTuple(shape)); + if (IsArray(shape)) { + return ElementsIn(shape); + } + int64 count = 0; + for (const Shape& element_shape : shape.tuple_shapes()) { + count += ElementsInRecursive(element_shape); + } + return count; +} + +/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) { + return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } /* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { return shape.element_type() == F32 && Rank(shape) == 0; } - namespace { // Class to memoize the computation of @@ -645,15 +665,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - if (IsArray(lhs)) { - return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); - } else if (lhs.element_type() == TUPLE) { - return rhs.element_type() == TUPLE && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); - } else { - // Opaque, token, etc types are vacuously compatible. - return true; - } + return CompareShapes(lhs, rhs, /*compare_layouts=*/false); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, @@ -903,6 +915,21 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return *return_shape; } +/* static */ StatusOr ShapeUtil::TryGetSubshape( + const Shape& shape, ShapeIndexView index) { + const Shape* return_shape = &shape; + for (auto i : index) { + if (!IsTuple(*return_shape) || i < 0 || + i >= return_shape->tuple_shapes_size()) { + return InvalidArgument( + "Shape index %s not a valid subshape index for tuple with shape %s", + index.ToString().c_str(), shape.DebugString().c_str()); + } + return_shape = &return_shape->tuple_shapes(i); + } + return return_shape; +} + /* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape, ShapeIndexView index) { Shape* return_shape = shape; @@ -939,66 +966,9 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return leaves; } -/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { - CHECK(IsArray(shape)); - - std::vector dimension_sizes; - std::vector degenerate_dimensions; - for (int64 i = 0; i < shape.dimensions_size(); ++i) { - if (shape.dimensions(i) == 1) { - degenerate_dimensions.push_back(i); - } else { - dimension_sizes.push_back(shape.dimensions(i)); - } - } - - // Construct minor_to_major of stripped shape. The order of the non-degenerate - // dimensions should be preserved from the original shape. First, create - // vector of the non-degenerate dimensions from the original minor_to_major - // array. - std::vector minor_to_major; - for (int64 i : shape.layout().minor_to_major()) { - if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(), - i) == degenerate_dimensions.end()) { - minor_to_major.push_back(i); - } - } - - // The dimensions in minor_to_major need to be renumbered to account for the - // degenerate dimensions which have removed. Decrement each dimension number - // once for each degenerate dimension which has a smaller number. - for (int i = 0; i < minor_to_major.size(); ++i) { - int adjustment = 0; - for (int64 dim : degenerate_dimensions) { - if (minor_to_major[i] > dim) { - adjustment++; - } - } - minor_to_major[i] -= adjustment; - } - - { - std::vector dims(minor_to_major.size()); - std::iota(dims.begin(), dims.end(), 0); - DCHECK(minor_to_major.size() == dims.size() && - std::is_permutation(minor_to_major.begin(), minor_to_major.end(), - dims.begin())); - } - Shape stripped_shape; - if (LayoutUtil::IsDenseArray(shape)) { - stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes, - minor_to_major); - } else if (LayoutUtil::IsSparseArray(shape)) { - stripped_shape = - MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes, - shape.layout().max_sparse_elements()); - } else { - stripped_shape = MakeShape(shape.element_type(), dimension_sizes); - } - - VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape); - VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape); - return stripped_shape; +/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) { + CHECK(ShapeUtil::IsArray(shape)); + return ArrayContains(AsInt64Slice(shape.dimensions()), 1); } namespace { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 3853ada6ba65dbb1ac0754bcf753b4553ec260e7..02e4f41505f16de7369e1dbd712dd0756c3f28e7 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -62,6 +62,8 @@ class ShapeIndex { public: ShapeIndex() = default; ShapeIndex(std::initializer_list init) : indices_(init) {} + template + ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {} bool empty() const { return indices_.empty(); } size_t size() const { return indices_.size(); } @@ -132,6 +134,7 @@ class ShapeIndexView { ++new_begin; return ShapeIndexView(new_begin, end_); } + ShapeIndex ToShapeIndex() const { return ShapeIndex(begin_, end_); } bool operator==(const ShapeIndexView& other) const; bool operator!=(const ShapeIndexView& other) const; @@ -172,8 +175,11 @@ class ShapeUtil { // Precondition: IsArray(shape) static int64 ElementsIn(const Shape& shape); - // Returns true if 'shape' has zero elements. - static bool HasZeroElements(const Shape& shape); + // As ElementsIn(), but recurses through tuples. + static int64 ElementsInRecursive(const Shape& shape); + + // Returns true if 'shape' is an array with zero elements. + static bool IsZeroElementArray(const Shape& shape); // Returns the number of bytes required for an allocation of shape. The // |pointer_size| parameter is used for calculating the size of tuple @@ -333,7 +339,7 @@ class ShapeUtil { // Appends a major dimension to the shape with the given bound. static void AppendMajorDimension(int bound, Shape* shape); - // Returns an empty tuple shape. Can be used to indicate side-effects. + // Returns an empty tuple shape. Can be used as a sentinel Shape value. static Shape MakeNil() { return MakeTupleShape({}); } // Checks whether the shape is initialized. @@ -443,7 +449,7 @@ class ShapeUtil { // Returns true if shape is an empty tuple. static bool IsEmptyTuple(const Shape& shape); - // Returns true if shape is an empty tuple, or is an array with no elements. + // Returns true if shape is the nil shape (an empty tuple). static bool IsNil(const Shape& shape); // Returns the number of elements in the given tuple shape. @@ -454,6 +460,9 @@ class ShapeUtil { // Precondition: IsTuple(shape) && TupleElementCount(shape) > index static const Shape& GetTupleElementShape(const Shape& shape, int64 index); + // Returns the number of elements, recursively, in the given shape. + static int64 SubshapeCount(const Shape& shape); + // Slices tuple elements in the range [start, limit) and returns a new tuple // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32). static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit); @@ -473,8 +482,11 @@ class ShapeUtil { static bool IndexIsValid(const Shape& shape, ShapeIndexView index); // GetSubshape and GetMutableSubshape return a particular nested Shape within - // the given Shape argument. + // the given Shape argument. The non-Try variants check fail if index is + // invalid. static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index); + static StatusOr TryGetSubshape(const Shape& shape, + ShapeIndexView index); static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index); // Returns whether the given index in the given shape is a leaf element of the @@ -510,25 +522,9 @@ class ShapeUtil { static Status ForEachMutableSubshapeWithStatus( Shape* shape, const MutatingStatusVisitorFunction& func); - // Removes all degenerate dimensions (size one) from the given shape. The - // stripped minor_to_major preserves the relative ordering of non-degenerate - // dimensions. The stripped shape has the property that the underlying - // representation (bits in memory) for the stripped shape is the same as the - // original shape modulo padding. Examples: - // - // input shape: F32 [1, 2, 1], minor_to_major = {0, 1, 2} - // stripped shape: F32 [2], minor_to_major = {0} - // - // input shape: F32 [6, 1, 5], minor_to_major = {2, 0, 1} - // stripped shape: F32 [6, 5], minor_to_major = {1, 0} - // - // input shape: F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1} - // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1} - // - // input shape: F32 [1, 1], minor_to_major = {0, 1} - // stripped shape: F32 [], minor_to_major = {} - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) - static Shape StripDegenerateDimensions(const Shape& shape); + // Returns true if `shape` (which must be an array) with degenerate dimensions + // (dimensions with bound 1). + static bool HasDegenerateDimensions(const Shape& shape); // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i] @@ -714,7 +710,7 @@ class ShapeUtil { tensorflow::gtl::ArraySlice incr, const FnType& visitor_function, bool parallel = false) { - if (ShapeUtil::HasZeroElements(shape)) { + if (ShapeUtil::IsZeroElementArray(shape)) { return Status::OK(); } CHECK_EQ(Rank(shape), base.size()); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index ecdb6532f1d743c7dacc266eeba615e19748ee27..606f7492cead5c3b6772625612fec67296740c7f 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -172,6 +172,41 @@ TEST(ShapeUtilTest, CompatibleIdenticalShapes) { ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2)); } +TEST(ShapeUtilTest, TokenCompatibility) { + EXPECT_TRUE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTokenShape())); + EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShape(F32, {}))); + EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape())); + EXPECT_TRUE(ShapeUtil::Compatible( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}))); +} + +TEST(ShapeUtilTest, TokensEqualShapes) { + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTokenShape())); + EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShape(F32, {}))); + EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape())); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}))); + EXPECT_FALSE(ShapeUtil::Equal( + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {1, 0})}))); +} + TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) { Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); auto layout_1 = shape_1.mutable_layout(); @@ -329,6 +364,16 @@ TEST(ShapeUtilTest, ByteSizeOfWithPadding) { EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape)); } +TEST(ShapeUtilTest, NilShape) { + EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3}))); + EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1}))); + EXPECT_FALSE(ShapeUtil::IsNil( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})}))); + EXPECT_FALSE(ShapeUtil::IsNil( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})}))); +} + TEST(ShapeUtilTest, NestedTuple) { EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({}))); EXPECT_FALSE(ShapeUtil::IsNestedTuple( @@ -359,25 +404,30 @@ TEST(ShapeUtilTest, ElementsIn) { EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17}))); } -TEST(ShapeUtilTest, HasZeroElements) { - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {}))); - EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0}))); - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1}))); - EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5}))); - EXPECT_EQ(true, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5}))); - EXPECT_EQ(true, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5}))); - EXPECT_EQ(false, - ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17}))); +TEST(ShapeUtilTest, IsZeroElementArray) { + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {}))); + EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 1}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2, 1}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 5}))); + EXPECT_TRUE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 0, 5}))); + EXPECT_TRUE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0, 3, 0}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 3, 5}))); + EXPECT_FALSE( + ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {13, 17}))); + + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeNil())); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeTupleShape({}))); + EXPECT_FALSE(ShapeUtil::IsZeroElementArray( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {0, 3, 0})}))); } TEST(ShapeUtilTest, SameDimensions) { @@ -742,14 +792,15 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } -TEST(ShapeUtilTest, StripDegenerateDimensions) { - EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions( - ShapeUtil::MakeShape(F32, {3, 1, 2})), - ShapeUtil::MakeShape(F32, {3, 2}))); - EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::StripDegenerateDimensions( - ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)), - ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10))); +TEST(ShapeUtilTest, HasDegenerateDimensions) { + EXPECT_TRUE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 1, 2}))); + EXPECT_TRUE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 1, 1}))); + EXPECT_FALSE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 3, 5}))); + EXPECT_FALSE( + ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 0, 5}))); } TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 7f6bbe6f879fd9596601f99f034a0391a71c52f8..b76830f6662ba3bda2b0c64c88cd74f0a13f75f0 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1203,6 +1203,22 @@ xla_test( ], ) +xla_test( + name = "token_hlo_test", + srcs = ["token_hlo_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "call_test", srcs = ["call_test.cc"], @@ -1970,6 +1986,7 @@ xla_test( "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 36a706496918ac8c15780473019e2a8d098ffa22..c3a289ee09cc1ee7b9d705a38c26a3ac7a8a6aa2 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -2758,7 +2758,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { ASSERT_FALSE(computation_status.ok()); EXPECT_THAT(computation_status.status().ToString(), ::testing::ContainsRegex( - "Expected non-opaque argument for lhs of binary operation")); + "Expected array argument for lhs of binary operation")); } XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 34c86e007beea1cbac04641bdbdab62dc567f13e..3a0f51fc66d65c8684bd607b9e8103559cd4d8d4 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -671,7 +671,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { @@ -684,7 +684,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index a4c8a83eb15f7cc279b6c8f1bf1394c0afb9f7cf..352864502a184237fde600330836fe471a5444f2 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -417,7 +417,22 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) { ASSERT_FALSE(computation_status.ok()); EXPECT_THAT( computation_status.status().ToString(), - HasSubstr("Expected non-opaque argument for operand of concatenation")); + HasSubstr("Expected array argument for operand of concatenation")); +} + +// Show that we can't concatenate with tokens. +XLA_TEST_F(ConcatTest, CannotConcatTokens) { + XlaBuilder builder(TestName()); + auto token_shape = ShapeUtil::MakeTokenShape(); + auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1}); + auto x = builder.Parameter(0, r1f32, "x"); + auto y = builder.Parameter(1, token_shape, "y"); + builder.ConcatInDim({x, y}, 0); + StatusOr computation_status = builder.Build(); + ASSERT_FALSE(computation_status.ok()); + EXPECT_THAT( + computation_status.status().ToString(), + HasSubstr("Expected array argument for operand of concatenation")); } XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) { diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 722d882471a41a75c1e5e60f8c1a151b76c7e004..3a885b43893f84fb331572343308130bb06f7e86 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -461,5 +461,26 @@ XLA_TEST_F(ConvertTest, ConvertS64U64) { ComputeAndCompareR1(&builder, unsigned_x, {}); } +XLA_TEST_F(ConvertTest, ConvertBF16F32) { + XlaBuilder builder(TestName()); + + std::vector all_bfloats(1 << 16); + for (int i = 0; i < all_bfloats.size(); ++i) { + all_bfloats[i].value = i; + } + + std::vector expected(all_bfloats.size()); + for (int i = 0; i < expected.size(); ++i) { + expected[i] = (1U << 16) * i; + } + + // Exhaustively test all bf16 to f32 conversions. + xla::XlaOp all_bfloats_bf16 = builder.ConstantR1(all_bfloats); + xla::XlaOp all_bfloats_f32 = + builder.ConvertElementType(all_bfloats_bf16, F32); + xla::XlaOp all_bfloats_u32 = builder.BitcastConvertType(all_bfloats_f32, U32); + ComputeAndCompareR1(&builder, expected, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 947959beb144e1509a77ad2f94b8493de46ba6f2..346bb3a3996ee5bf662b0f74dd0c2096efbf5295 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -47,9 +47,9 @@ class ConvolutionTest : public ClientLibraryTestBase { #if XLA_TEST_BACKEND_GPU // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial // convolution. So relax the absolute error threshold. - ErrorSpec error_spec_ = ErrorSpec(1e-2); + ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-4); #else - ErrorSpec error_spec_ = ErrorSpec(1e-4); + ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-4); #endif }; diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index c960b3c15faf979c23c3c2f6a745de8b359465c2..b151187c4b8f01c5b46ccadf27d2e22a7c902e98 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -32,9 +32,16 @@ class TrivialCrossReplicaSumTest : public HloTestBase {}; XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { const char* module_str = R"( HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + ENTRY test_computation { p = f32[3] parameter(0) - ROOT crs = f32[3] cross-replica-sum(p) + ROOT crs = f32[3] cross-replica-sum(p), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); @@ -45,10 +52,17 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { const char* module_str = R"( HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] parameter(1) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); @@ -65,10 +79,17 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { const char* module_str = R"( HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] constant({10, 20}) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add })"; auto module = ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 49f3a10d227f2f9edfe76405ba13498fe822f8d8..a918c91f07ff0241845df4ef99334020859d8311 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -716,8 +716,10 @@ void BM_DynamicSlice(int num_iters) { .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *start_indices_literal, buffer)); + stream.get(), *start_indices_literal, buffer)); std::unique_ptr executable = client diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 143ffbdeb409d91ab6d46d386aa5ff98ebc4ae10..6fefae36958011c918cedc6703289551b00acc80 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -629,8 +629,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { client_->ExecuteParallel(computation_instances)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, client_->Transfer(*(result_data[0]))); - EXPECT_TRUE(LiteralTestUtil::Equal( - *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}))); + LiteralTestUtil::ExpectR2Equal({{1, 2, 3}, {7, 8, 9}}, + *result_literal); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 08ed826c80823efe0af8ce682945fe7e46d267ae..242cc5db11ff2bdf69209df7537216573d8afbf3 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -94,8 +94,7 @@ HloTestBase::HloTestBase(se::Platform* test_platform, /* static */ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { - return MakeUnique(name, VersionedComputationHandle(), - GetModuleConfigForTest()); + return MakeUnique(name, GetModuleConfigForTest()); } /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index eb3a2ea76a667a2afa2562f01d28f34384b84a21..9009d67cea6840235d63724ef76d777c8f693d33 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -66,6 +66,15 @@ namespace xla { // // For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { + public: + // Creates a new HLO module for a test. The module created will have + // TestName() for its name; it will also automatically populate its debug + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. + static std::unique_ptr CreateNewModule( + const string& name = TestName()); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the @@ -80,14 +89,6 @@ class HloTestBase : public ::testing::Test { ~HloTestBase() override {} - // Creates a new HLO module for a test. The module created will have - // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. If you want a fresh HloModule object and - // then add HloComputations to it, it's recommended to use this method in your - // tests. - static std::unique_ptr CreateNewModule( - const string& name = TestName()); - // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. @@ -184,13 +185,9 @@ class HloTestBase : public ::testing::Test { // 'layout'. void ForceParameterLayout(HloModule* module, int64 param_no, const Layout& layout) { - ASSERT_LT( - param_no, - module->mutable_host_entry_computation_layout()->parameter_count()); - module->mutable_host_entry_computation_layout() - ->mutable_parameter_layout(param_no) - ->ResetLayout(layout); - module->mutable_device_entry_computation_layout() + ASSERT_LT(param_no, + module->mutable_entry_computation_layout()->parameter_count()); + module->mutable_entry_computation_layout() ->mutable_parameter_layout(param_no) ->ResetLayout(layout); } @@ -198,10 +195,7 @@ class HloTestBase : public ::testing::Test { // Convenience method to force the layout of the computation result in a // module. The result layout of 'module' is set to 'layout'. void ForceResultLayout(HloModule* module, const Layout& layout) { - module->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->ResetLayout(layout); - module->mutable_device_entry_computation_layout() + module->mutable_entry_computation_layout() ->mutable_result_layout() ->ResetLayout(layout); } @@ -209,10 +203,7 @@ class HloTestBase : public ::testing::Test { // Convenience method to clear the layout of the computation result in // 'module'. void ForceClearResultLayout(HloModule* module) { - module->mutable_host_entry_computation_layout() - ->mutable_result_layout() - ->Clear(); - module->mutable_device_entry_computation_layout() + module->mutable_entry_computation_layout() ->mutable_result_layout() ->Clear(); } diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index c8a05c2e9e971d86feb6ff893fcd25c6767af99f..ad1f5b9eed8b5b140100c1fa35dc7d698e3db48b 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -41,14 +41,17 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - VerifyModule(); + VerifyModule(module_.get()); + } + for (int i = 0; i < modules_.size(); ++i) { + VerifyModule(modules_.at(i).get()); } HloTestBase::TearDown(); } -void HloVerifiedTestBase::VerifyModule() { - HloVerifier verifier; - xla::StatusOr mutated = verifier.Run(module_.get()); +void HloVerifiedTestBase::VerifyModule(HloModule* module) { + HloVerifier verifier(/*allow_mixed_precision=*/true); + xla::StatusOr mutated = verifier.Run(module); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { @@ -59,15 +62,20 @@ void HloVerifiedTestBase::VerifyModule() { HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = CreateNewModule(); + module_ = HloTestBase::CreateNewModule(); } return *module_; } -void HloVerifiedTestBase::ParseAndVerifyModule( - tensorflow::StringPiece hlo_text) { +HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { + modules_.emplace_back(HloTestBase::CreateNewModule()); + return modules_.back().get(); +} + +void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); - VerifyModule(); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); + VerifyModule(module_.get()); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index e5bb14a8839acbdef8fd2b79bb0f574c46ea3d40..5b28c01c369fa1ae1c7941f5c8139882c4dbed08 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -44,7 +44,8 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); - void ParseAndVerifyModule(tensorflow::StringPiece hlo_text); + void ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + const HloModuleConfig& config = HloModuleConfig()); // Sets the shape-size function used during hlo verification. If this isn't // called, a default ShapeVerifier is used instead. @@ -52,11 +53,23 @@ class HloVerifiedTestBase : public HloTestBase { shape_verifier_ = std::move(shape_verifier); } + // Creates a new module for a test, and stores it in modules_ so it can be + // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent + // creation of unverified modules. + HloModule* CreateNewModule(const string& name = TestName()); + + // It is confusing to store modules created by module() and CreateNewModule() + // in different fields, but it allows us to migrate tests to + // HloVerifiedTestBase more easily, so it's a win because we can verify more + // modules. See b/80488902. private: - std::unique_ptr module_; // Lazily populated. Access via module(). + // Lazily populated. Access via module(). + std::unique_ptr module_; + // Populated by calls to CreateNewModule. + std::vector> modules_; std::unique_ptr shape_verifier_; bool tear_down_called_ = false; - void VerifyModule(); + static void VerifyModule(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 2f46ee0be216d7dabf1c476d3cfb7d528f8ab6a4..082bc34136e004795ce300c66591758f47c665fe 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -124,8 +124,7 @@ class LLVMCompilerTest : public ::testing::Test { static std::unique_ptr CreateNewModule() { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return MakeUnique(TestName(), VersionedComputationHandle(), - config); + return MakeUnique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 96858c00d6bbe59b673a34e7d5ca261756709596..5a70c2a9ae5f32f27ec012d554c183159a63576c 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -209,13 +209,12 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -238,17 +237,14 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 0})); - LiteralTestUtil::ExpectR2Equal( - {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralSlice(*result_literal, {0, 1})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralSlice(*result_literal, {0, 2})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0, 0})); + LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralSlice(*result_literal, {0, 1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -273,10 +269,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { options, DefaultExecutableRunOptions()); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -319,11 +315,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -360,10 +355,10 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); std::unique_ptr result_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, + LiteralSlice(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, + LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -389,18 +384,17 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); - LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralSlice(*result_0_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, + LiteralSlice(*result_0_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, + LiteralSlice(*result_0_literal, {1})); ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); - LiteralTestUtil::ExpectR2Equal( - {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0})); - LiteralTestUtil::ExpectR2Equal( - {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1})); + LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, + LiteralSlice(*result_1_literal, {0})); + LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, + LiteralSlice(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -447,8 +441,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), - error_spec_); + {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); } } @@ -547,8 +540,8 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } - LiteralTestUtil::ExpectR0Equal( - 165.0, LiteralSlice(*result_literal, index)); + LiteralTestUtil::ExpectR0Equal(165.0, + LiteralSlice(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -753,10 +746,10 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); - LiteralTestUtil::ExpectR1Equal( - {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0})); - LiteralTestUtil::ExpectR1Equal( - {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1})); + LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, + LiteralSlice(*tuple_literal, {0})); + LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, + LiteralSlice(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { @@ -900,8 +893,10 @@ void BM_LocalClientOverhead(int num_iters) { ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *literal, buffer)); + auto stream = + client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie(); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal, + buffer)); const int kWarmups = 2; @@ -911,11 +906,8 @@ void BM_LocalClientOverhead(int num_iters) { std::unique_ptr executable = executable_status.ConsumeValueOrDie(); - se::Stream stream(executors[client->default_device_ordinal()]); - stream.Init(); - ExecutableRunOptions run_options; - run_options.set_allocator(&allocator).set_stream(&stream); + run_options.set_allocator(&allocator).set_stream(stream.get()); for (int i = 0; i < kWarmups; ++i) { auto result = executable->Run({&buffer}, run_options); diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 7df45bebebdd3eb2e71f27d831a8e2ac9e3b5f7c..3975e9125703ee081d4e84fa8bd27fcbe483ac34 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -488,10 +488,9 @@ TEST_F(MapTest, MapOperantionWithBuildError) { StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); - EXPECT_THAT( - computation_status.status().ToString(), - ::testing::HasSubstr("error from: ErrorAdd: Binary op BINOP_ADD with " - "different element types: f32[] and u16[]")); + EXPECT_THAT(computation_status.status().ToString(), + ::testing::HasSubstr("error from: ErrorAdd: Binary op add with " + "different element types: f32[] and u16[]")); } // MapTest disables inline and algsimp. MapTestWithFullOpt runs all diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 3cbb2452fb245b6703d3bcd5771a51f6e30aa593..a42a19af15e87bd58d16294d012ec4db31e90070 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -204,10 +204,10 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { Literal::CreateR0(1.0)), Literal::MakeTupleOwned(Literal::CreateR0(3.0), Literal::CreateR0(4))); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); + *Literal::MakeTupleOwned(Literal::CreateR0(42)), *result)); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { @@ -233,10 +233,9 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); - EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::CreateR1({0.0, 4.0, 9.0, 1.0}))); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0, 1.0}, *result); } XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { @@ -267,10 +266,9 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR1({1.0, 2.0, 3.0}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); - EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::CreateR1({0.0, 4.0, 9.0}))); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + LiteralTestUtil::ExpectR1Equal({0.0, 4.0, 9.0}, *result); } const char* const kScalarOps = R"( @@ -311,12 +309,12 @@ XLA_TEST_F(MultiOutputFusionTest, HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned(Literal::CreateR2({{3, 7}, {11, 15}}), - Literal::CreateR2({{5, 16}, {36, 64}})))); + Literal::CreateR2({{5, 16}, {36, 64}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -341,12 +339,12 @@ XLA_TEST_F(MultiOutputFusionTest, HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned( - Literal::CreateR2({{6, 8}, {10, 12}}), - Literal::CreateR2({{25, 36}, {49, 64}})))); + *Literal::MakeTupleOwned(Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR2({{25, 36}, {49, 64}})), + *result)); } XLA_TEST_F(MultiOutputFusionTest, @@ -357,9 +355,9 @@ XLA_TEST_F(MultiOutputFusionTest, c0 = f32[] constant(0) r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add mul = f32[2,2,2]{2,1,0} multiply(p0, p0) - c1 = f32[] constant(5) + c1 = f32[] constant(1.17549e-38) r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max - r3 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Add + r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3) } @@ -372,12 +370,186 @@ XLA_TEST_F(MultiOutputFusionTest, HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); - TF_ASSERT_OK_AND_ASSIGN(auto result, - Execute(std::move(module), {param.get()})); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), + Literal::CreateR1({36, 64}), + Literal::CreateR1({66, 138})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) + tuple(p0, r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned( + Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), + Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) + tuple(r1, mul, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned( + Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + Literal::CreateR2({{25, 36}, {49, 64}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + mul2 = f32[2,2,2]{2,1,0} multiply(p0, c1) + ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) + tuple(r1, mul, mul2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned( + Literal::CreateR1({14, 22}), + Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + Literal::CreateR3( + {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + init1 = f32[] parameter(1) + init2 = f32[] parameter(2) + r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add + r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + i = f32[] parameter(1) + j = f32[] parameter(2) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto init1 = Literal::CreateR0(5); + auto init2 = Literal::CreateR0(6); + std::unique_ptr result = ExecuteNoHloPasses( + std::move(module), {param.get(), init1.get(), init2.get()}); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::MakeTupleOwned( + Literal::CreateR2({{167, 172}, {176, 180}}), + Literal::CreateR2({{6, 6}, {6, 8}})), + *result)); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) { + p0 = f16[2,2,2]{2,1,0} parameter(0) + convert = f32[2,2,2]{2,1,0} convert(p0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(convert, convert) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) + tuple(r1, r2, p0) + } + + ENTRY reduce { + p = f16[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3( + {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}}); + std::unique_ptr result = + ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), - Literal::CreateR1({36, 64}), - Literal::CreateR1({391, 463})))); + *Literal::MakeTupleOwned( + Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}}), + Literal::CreateR3({{{Eigen::half(1), Eigen::half(2)}, + {Eigen::half(3), Eigen::half(4)}}, + {{Eigen::half(5), Eigen::half(6)}, + {Eigen::half(7), Eigen::half(8)}}})), + *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index dd7c541733634213606b5a7983b59bb1f14bf75c..000535a982fb08af69e7b317501f82ba7f402fb9 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -270,14 +270,22 @@ StatusOr> CreateLiteralForConstrainedUses( switch (use->opcode()) { case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr && - !ShapeUtil::Equal(needs_index->shape(), use->shape())) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + if (needs_index != nullptr) { + auto needs_index_shape = needs_index->shape(); + auto use_shape = use->shape(); + if (needs_index->opcode() == HloOpcode::kDynamicSlice) { + needs_index_shape = needs_index->operand(0)->shape(); + } + if (use->opcode() == HloOpcode::kDynamicSlice) { + use_shape = use->operand(0)->shape(); + } + if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { + return Unimplemented( + "Conflicting operand generation slice index constraints\n"); + } } needs_index = use; break; - case HloOpcode::kReduce: case HloOpcode::kReduceWindow: needs_constant = use; diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8541698576fb8aae1e3528cb618b367f843b8d53 --- /dev/null +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -0,0 +1,216 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class TokenHloTest : public HloTestBase {}; + +XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, TokenTree) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateGenerateToken({token0, token0, token1, token2})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 1 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}), + "param")); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenRoot) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Entry root is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction(HloInstruction::CreateGenerateToken({param})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr( + "Operands of token instructions must be TOKEN types")); +} + +XLA_TEST_F(TokenHloTest, TokenInWhileLoop) { + // Thread a token around a while loop. Token is created and consumed by a + // GenerateToken instruction in the while body. + string module_string = R"( +HloModule TokenInWhileLoop + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %generate-token = token[] generate-token(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %generate-token) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %TokenInWhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] generate-token() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + DebugOptions debug_options = GetDebugOptionsForTest(); + // Module DCE pass removes the generate token instructions. + debug_options.add_xla_disable_hlo_passes("hlo-module-dce"); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + + EXPECT_TRUE(RunAndCompare(std::move(module), error_spec_)); +} + +XLA_TEST_F(TokenHloTest, TokenInConditional) { + string module_string = R"( +HloModule TokenInConditional + +%True (param.1: token[]) -> (s32[], token[]) { + %param.1 = token[] parameter(0) + %forty_two = s32[] constant(42) + ROOT %tuple = (s32[], token[]) tuple(s32[] %forty_two, token[] %param.1) +} + +%False (param.2: s32[]) -> (s32[], token[]) { + %param.2 = s32[] parameter(0) + %new_token = token[] generate-token() + ROOT %tuple = (s32[], token[]) tuple(s32[] %param.2, token[] %new_token) +} + +ENTRY %TokenInConditional (param.3: pred[]) -> s32[] { + %param.3 = pred[] parameter(0) + %init_token = token[] generate-token() + %seven = s32[] constant(7) + %cond = (s32[], token[]) conditional(pred[] %param.3, token[] %init_token, s32[] %seven), true_computation=True, false_computation=False + ROOT %root = s32[] get-tuple-element((s32[], token[]) %cond), index=0 +} +)"; + + DebugOptions debug_options = GetDebugOptionsForTest(); + // Module DCE pass removes the generate token instructions. + debug_options.add_xla_disable_hlo_passes("hlo-module-dce"); + + { + // True case. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + auto arg = Literal::CreateR0(true); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {arg.get()})); + EXPECT_EQ(42, result->Get({})); + } + + { + // False case. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloRunner::CreateModuleFromString(module_string, debug_options)); + auto arg = Literal::CreateR0(false); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + Execute(std::move(module), {arg.get()})); + EXPECT_EQ(7, result->Get({})); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 0063e7ad415e9b6718c164f415ced6fb76cbf44a..85799d4cfb4838d91bd51c8d24d7ca70b41e6df1 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -41,7 +42,12 @@ class TransferManagerTest : public LocalClientTestBase { TransferManagerTest() : shape_size_fn_([this](const Shape& shape) { return transfer_manager_->GetByteSizeRequirement(shape); - }) {} + }) { + stream_ptr_ = local_client_->mutable_backend() + ->BorrowStream(stream_executor_) + .ValueOrDie(); + stream_ = stream_ptr_.get(); + } ~TransferManagerTest() override = default; @@ -53,6 +59,10 @@ class TransferManagerTest : public LocalClientTestBase { .ValueOrDie(); } + protected: + Backend::StreamPtr stream_ptr_; + se::Stream* stream_; + private: std::function shape_size_fn_; }; @@ -63,11 +73,11 @@ XLA_TEST_F(TransferManagerTest, TransferR0U32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR0Equal(42, *result); } @@ -79,11 +89,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, *result); @@ -97,11 +107,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR1Equal(test_vector, *result); } @@ -113,11 +123,11 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_EQ(result->GetR1U8AsString(), test_string); } @@ -129,11 +139,11 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) { auto device_buffer = AllocateDeviceBuffer(shape); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); @@ -149,11 +159,11 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_FALSE( LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); @@ -169,11 +179,11 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -183,11 +193,11 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -203,11 +213,11 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -218,11 +228,11 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } @@ -237,14 +247,150 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { auto device_buffer = AllocateDeviceBuffer(literal->shape()); // Round trip literal through device. - ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, device_buffer)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - transfer_manager_->TransferLiteralFromDevice( - stream_executor_, device_buffer)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } +XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) { + const int64 kIterationCount = 5000; + std::unique_ptr literal1 = Literal::MakeTuple( + {Literal::CreateR0(123.0f).get(), + Literal::MakeTuple( + {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + .get(), + Literal::CreateR1({-10.0f, 123.0f}).get()}); + std::unique_ptr literal2 = Literal::MakeTuple( + {Literal::CreateR0(456.0f).get(), + Literal::MakeTuple( + {Literal::CreateR2({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(), + Literal::CreateR1({44.0f, -11.0f, 3333333.3f}).get()}) + .get(), + Literal::CreateR1({-98.0f, 153.0f}).get()}); + + auto device_buffer1 = AllocateDeviceBuffer(literal1->shape()); + auto device_buffer2 = AllocateDeviceBuffer(literal2->shape()); + + auto stream1 = stream_; + auto stream2 = stream_->GetOrCreateSubStream(); + + std::unique_ptr result1, result2; + + // Round trip literals through device in multiple streams asynchronously. + for (int i = 0; i < kIterationCount; ++i) { + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1, + device_buffer1)); + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2, + device_buffer2)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr this_result1, + transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr this_result2, + transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2)); + result1 = std::move(this_result1); + result2 = std::move(this_result2); + } + + EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1)); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2)); +} + +class TransferDeviceToHostBenchmark : public TransferManagerTest { + public: + using TransferManagerTest::TransferManagerTest; + ~TransferDeviceToHostBenchmark() override {} + + void Run(int iters, int num_tuple_elements, int array_size) { + tensorflow::testing::StopTiming(); + SetUp(); + + std::vector> tuple_elements; + for (int i = 0; i < num_tuple_elements; ++i) { + tuple_elements.push_back( + Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + } + std::unique_ptr literal = + Literal::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer)); + } + tensorflow::testing::StopTiming(); + TearDown(); + } + + void TestBody() override {} +}; + +class TransferHostToDeviceBenchmark : public TransferManagerTest { + public: + using TransferManagerTest::TransferManagerTest; + ~TransferHostToDeviceBenchmark() override {} + + void Run(int iters, int num_tuple_elements, int array_size) { + tensorflow::testing::StopTiming(); + SetUp(); + + std::vector> tuple_elements; + for (int i = 0; i < num_tuple_elements; ++i) { + tuple_elements.push_back( + Literal::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size)); + } + std::unique_ptr literal = + Literal::MakeTupleOwned(std::move(tuple_elements)); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal, + device_buffer)); + } + tensorflow::testing::StopTiming(); + TearDown(); + } + + void TestBody() override {} +}; + +void BM_TransferDeviceToHost(int iters, int num_tuple_elements, + int array_size) { + TransferDeviceToHostBenchmark bm; + bm.Run(iters, num_tuple_elements, array_size); +} + +void BM_TransferHostToDevice(int iters, int num_tuple_elements, + int array_size) { + TransferHostToDeviceBenchmark bm; + bm.Run(iters, num_tuple_elements, array_size); +} + +BENCHMARK(BM_TransferHostToDevice) + ->ArgPair(1, 256) + ->ArgPair(1, 257) + ->ArgPair(100, 256) + ->ArgPair(100, 257); + +BENCHMARK(BM_TransferDeviceToHost) + ->ArgPair(1, 256) + ->ArgPair(1, 257) + ->ArgPair(100, 256) + ->ArgPair(100, 257); + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + tensorflow::testing::RunBenchmarks(); + return RUN_ALL_TESTS(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 41189231b90e842292830a932cf381af60456d4c..220d9f6320632cae2c51f71cc7c568120bda1f04 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -532,8 +532,8 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( - *result, - *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})))); + *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})), + *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index 3c9a01653c67203cbc962a3d3d967142f7a2102c..0be950cacbf07eece9aff9ffe1d0e571e9b25038 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -128,20 +128,23 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, se::StreamExecutor* executor = backend->default_stream_executor(); DeviceMemoryAllocator* allocator = backend->memory_allocator(); auto* transfer_manager = backend->transfer_manager(); + TF_ASSERT_OK_AND_ASSIGN( + Backend::StreamPtr stream_ptr, + backend->BorrowStream(backend->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer lhs_arg, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); + stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); + stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, @@ -153,9 +156,6 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, &executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); - TF_ASSERT_OK_AND_ASSIGN( - Backend::StreamPtr stream_ptr, - backend->BorrowStream(backend->default_device_ordinal())); ExecutableRunOptions exec_run_options; exec_run_options.set_stream(stream_ptr.get()); exec_run_options.set_allocator(backend->memory_allocator()); diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index a9f2915b458b1816926de727b3da21982d06f6c0..a075195618c42aaa11f7b1c17730e67889a2c308 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -49,6 +49,7 @@ GTEST_API_ int main(int argc, char** argv) { } // Unfortunately Google's internal benchmark infrastructure has a // different API than Tensorflow's. + testing::InitGoogleTest(&argc, argv); #if defined(PLATFORM_GOOGLE) base::SetFlag(&FLAGS_benchmarks, pattern); RunSpecifiedBenchmarks(); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index d73bcdaf82ad5850e68b5d067fc86201a096c434..e4a052c8f1c0009619c3a94606f6384d04006e4e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -85,6 +85,7 @@ cc_library( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/compiler/xla/tests:test_utils", @@ -135,7 +136,7 @@ tf_cc_binary( deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/tools/convert_computation.cc b/tensorflow/compiler/xla/tools/convert_computation.cc index fe03a6e7bdfe99877c250fe1ae22beee4c8018a2..14d01b5bfb067cc39abc4d6e0605007624b6e0ae 100644 --- a/tensorflow/compiler/xla/tools/convert_computation.cc +++ b/tensorflow/compiler/xla/tools/convert_computation.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/env.h" @@ -33,7 +33,7 @@ namespace xla { namespace tools { void RealMain(const string& mode, const string& path) { - SessionModule module; + HloSnapshot module; tensorflow::Env* env = tensorflow::Env::Default(); if (mode == "txt2bin") { TF_CHECK_OK(tensorflow::ReadTextProto(env, path, &module)); diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index be094b7890aab08c55686c4785e01ff2ffba7cc2..f7574e0b1cc95daee6d6743ba4e2e490ee87e7c6 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -24,6 +24,9 @@ limitations under the License. // passing --use_fake_data on the command line. If the real data is available // in the proto and --use_fake_data is false, the real data is used. // +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// // The output format is: // // file_path: computation_name :: type:literal_str @@ -43,6 +46,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -195,25 +199,45 @@ StatusOr ReplayComputation(const HloSnapshot& module, return std::move(*result_literal); } +StatusOr ParseInputFile(const string& filename, + const Options& opts) { + tensorflow::Env* env = tensorflow::Env::Default(); + HloSnapshot snapshot; + if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) { + return snapshot; + } + CHECK(opts.use_fake_data) + << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " + "and textual HLO don't carry real data."; + fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", + filename.c_str()); + + if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) { + return snapshot; + } + fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str()); + string contents; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents)); + StatusOr> module = ParseHloString(contents); + if (module.ok()) { + *snapshot.mutable_hlo()->mutable_hlo_module() = + module.ValueOrDie()->ToProto(); + return snapshot; + } + fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", + filename.c_str()); + return InvalidArgument("Could not parse %s.", filename.c_str()); +} + int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { LocalClient* client = ClientLibrary::LocalClientOrDie(); - tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { - HloSnapshot snapshot; - auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); - if (!status.ok()) { - fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", arg); - status = tensorflow::ReadBinaryProto(env, arg, snapshot.mutable_hlo()); - if (!status.ok()) { - fprintf(stderr, "%s: is not HloSnapshot or HloProto: %s.\n", arg, - status.ToString().c_str()); - continue; - } - CHECK(opts.use_fake_data) - << "HloProto input must be handled with --use_fake_data"; + StatusOr maybe_snapshot = ParseInputFile(arg, opts); + if (!maybe_snapshot.ok()) { + continue; } - + HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); StatusOr result_status = ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index b4f45cc972d3d397ddff8e8d9163d1fef387392f..6041fae1595dacb309008857f1c758ee96a646bb 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -539,6 +540,11 @@ int64 FindIndex(const C& c, Value&& value) { return std::distance(c.begin(), it); } +template +bool ArrayContains(tensorflow::gtl::ArraySlice c, const T& value) { + return c_find(c, value) != c.end(); +} + template void InsertAt(C* c, int64 index, Value&& value) { c->insert(c->begin() + index, std::forward(value)); @@ -549,6 +555,12 @@ void EraseAt(C* c, int64 index) { c->erase(c->begin() + index); } +template +std::vector InlinedVectorToVector( + const tensorflow::gtl::InlinedVector& inlined_vector) { + return std::vector(inlined_vector.begin(), inlined_vector.end()); +} + // Returns true if `x` fits in 32-bits. template bool IsInt32(T x) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f619b8dc24038af64a27fc0565c74447ca9d09cf..6f07e4606bef015214f2c564515c8258a906205b 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -17,7 +17,6 @@ syntax = "proto3"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; -import "tensorflow/compiler/xla/service/session.proto"; package xla; @@ -226,22 +225,6 @@ message ExecutionOptions { repeated DeviceHandle device_handles = 5; } -message SnapshotComputationRequest { - ComputationHandle computation = 1; -} - -message SnapshotComputationResponse { - SessionModule module = 1; -} - -message LoadComputationSnapshotRequest { - SessionModule module = 1; -} - -message LoadComputationSnapshotResponse { - ComputationHandle computation = 1; -} - message GetDeviceHandlesRequest { int64 device_count = 1; } @@ -300,11 +283,6 @@ message ResetDeviceRequest { message ResetDeviceResponse { } -message ComputationStatsRequest { - ComputationHandle computation = 1; - DebugOptions debug_options = 2; -} - message ComputationGraphStatsRequest { HloModuleProto computation = 1; DebugOptions debug_options = 2; @@ -314,14 +292,6 @@ message ComputationStatsResponse { ComputationStats stats = 1; } -message ComputationRequest { - string name = 1; -} - -message ComputationResponse { - ComputationHandle computation = 1; -} - message CreateChannelHandleRequest { } @@ -336,24 +306,6 @@ message UnregisterRequest { message UnregisterResponse { } -message SetReturnValueRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message SetReturnValueResponse { -} - -message ExecuteRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 5; -} - message ExecuteGraphRequest { HloModuleProto computation = 1; repeated GlobalDataHandle arguments = 2; @@ -362,10 +314,6 @@ message ExecuteGraphRequest { ExecutionOptions execution_options = 3; } -message ExecuteParallelRequest { - repeated ExecuteRequest requests = 1; -} - message ExecuteGraphParallelRequest { repeated ExecuteGraphRequest requests = 1; } @@ -379,21 +327,6 @@ message ExecuteParallelResponse { repeated ExecuteResponse responses = 1; } -message ExecuteAsyncRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 6; -} - -message ExecuteAsyncResponse { - // A handle to the execution launched asynchronously. - ExecutionHandle execution = 1; -} - message WaitForExecutionRequest { ExecutionHandle execution = 1; } @@ -403,31 +336,13 @@ message WaitForExecutionResponse { ExecutionProfile profile = 2; } -message IsConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - int64 num_parameters = 3; -} - -message IsConstantResponse { - bool is_constant = 1; -} - -message ComputeConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - Layout output_layout = 3; - repeated LiteralProto parameters = 4; -} - message ComputeConstantGraphRequest { HloModuleProto computation = 1; Layout output_layout = 2; } message ComputeConstantResponse { - // A LiteralProto is returned directly for this request, instead of a - // ComputationDataHandle. + // A LiteralProto is returned directly for this request. LiteralProto literal = 1; } @@ -469,14 +384,6 @@ message LoadDataResponse { int64 nanoseconds = 5; } -message SpecializeRequest { - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; -} - -message SpecializeResponse { -} - message GetShapeRequest { GlobalDataHandle data = 1; } @@ -485,14 +392,6 @@ message GetShapeResponse { Shape shape = 1; } -message GetComputationShapeRequest { - ComputationHandle computation = 1; -} - -message GetComputationShapeResponse { - ProgramShape program_shape = 1; -} - message UnpackRequest { GlobalDataHandle data = 1; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 6bdfb0179cd6a5e4eaee20cd877bd976e0e173c3..c7472173a705b7a6e1bee2f5221f23db0a77991d 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -274,12 +274,9 @@ message ExecutionProfile { // for the input data transfer since the memory is initialized with the proper // values before the execution. int64 compute_and_transfer_time_ns = 5; -} -// Handle given to a user that represents a computation that the user builds up -// before execution. -message ComputationHandle { - int64 handle = 1; + // The size of the binary code in the executable. + int64 executable_size_in_bytes = 6; } // Handle given to a user that represents an execution that the user launched @@ -295,13 +292,6 @@ message GlobalDataHandle { int64 handle = 1; } -// Handle given to a user that represents a data result in a computation. -// This is used to pass to subsequent computations that depends upon the data as -// an operand. -message ComputationDataHandle { - int64 handle = 1; -} - // Handle given to a user that represents a replicated virtual device. Each // replicated device represents N physical devices for execution where N is the // number of replicas. @@ -441,44 +431,6 @@ message GatherDimensionNumbers { int64 index_vector_dim = 4; } -// Operation requests that are all collected as a tagged union with a oneof -// field in OpRequest. - -message ConstantRequest { - LiteralProto literal = 2; -} - -message GetTupleElementRequest { - ComputationDataHandle operand = 2; - int64 index = 3; -} - -message SliceRequest { - ComputationDataHandle operand = 2; - repeated int64 start_indices = 3; - repeated int64 limit_indices = 4; - repeated int64 strides = 5; -} - -message DynamicSliceRequest { - // Operand from which to slice at dynamic 'start_indices'. - ComputationDataHandle operand = 2; - // Dynamically computed 'start_indices' for slice operation. - ComputationDataHandle start_indices = 3; - // Slice sizes for each dimension (note that indices calculations are computed - // modulo dimension sizes to avoid out-of-bound array accesses). - repeated int64 slice_sizes = 4; -} - -message DynamicUpdateSliceRequest { - // Operand on which slice 'update' is to be applied. - ComputationDataHandle operand = 2; - // The slice update to apply to 'operand'. - ComputationDataHandle update = 3; - // Dynamically computed start indices for the update slice operation. - ComputationDataHandle start_indices = 4; -} - message ConvolutionDimensionNumbers { // The number of the dimension that represents batch in the input. int64 input_batch_dimension = 7; @@ -516,13 +468,6 @@ message ConvolutionDimensionNumbers { // Next = 13 }; -message ConvolveRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; // This is the filter/kernel. - Window window = 4; // Describes the filter/kernel. - ConvolutionDimensionNumbers dimension_numbers = 5; -} - enum FftType { FFT = 0; // Forward FFT; complex in, complex out. IFFT = 1; // Inverse FFT; complex in, complex out. @@ -531,56 +476,6 @@ enum FftType { // fft_length real out } -message FftRequest { - FftType fft_type = 1; - repeated int64 fft_length = 2; // Multivalent for higher-order FFT. - ComputationDataHandle operand = 3; -} - -message InfeedRequest { - // The shape of the data returned by reading the device's infeed buffer. - Shape shape = 2; - - // Additional infeed configuration for the backend. - bytes config = 3; -} - -message OutfeedRequest { - // The shape of the data returned by reading the device's outfeed buffer. - Shape shape = 1; - - // Operand to the Outfeed. Supports tuple. - ComputationDataHandle operand = 2; - - // Backend-specific information for how to perform the outfeed. - bytes outfeed_config = 3; -} - -message CallRequest { - ComputationHandle to_apply = 2; - repeated ComputationDataHandle operands = 3; -} - -message CustomCallRequest { - string call_target_name = 2; - repeated ComputationDataHandle operands = 3; - Shape shape = 4; -} - -message HostComputeRequest { - // Operand to the HostCompute. Supports tuple. - repeated ComputationDataHandle operands = 1; - - // Name used to identify HostSend/Recv channels. - string channel_name = 2; - - // Cost estimate in nanoseconds. - int64 cost_estimate_ns = 3; - - // The shape of any data returned by host. - Shape shape = 4; -} - message DotDimensionNumbers { // The dimension numbers that represent the 'lhs' contracting dimensions. repeated int64 lhs_contracting_dimensions = 1; @@ -592,297 +487,6 @@ message DotDimensionNumbers { repeated int64 rhs_batch_dimensions = 4; }; -message DotRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; - DotDimensionNumbers dimension_numbers = 4; -} - -message MapRequest { - repeated ComputationDataHandle operands = 2; - ComputationHandle to_apply = 3; - repeated ComputationDataHandle static_operands = 4; - // The dimensions over which to map. - // Example mapping a Dot operation along the batch dimension 0: - // operand0.shape = [2, 2, 2], operand1.shape = [2,2,3] - // Map({operand0, operand1}, Dot, {0}) - repeated int64 dimensions = 5; -} - -message ReduceRequest { - // Operand to the reduction. - ComputationDataHandle operand = 2; - - // Initial value for the reduction. This must be consistent with the result - // shape of to_apply. - ComputationDataHandle init_value = 3; - - // The dimensions to reduce over. - repeated int64 dimensions = 4; - - // The computation to apply in the reduction. - ComputationHandle to_apply = 5; -} - -message ReduceWindowRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle init_value = 3; - Window window = 4; - ComputationHandle to_apply = 5; -} - -message BatchNormTrainingRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - float epsilon = 4; - int64 feature_index = 5; -} - -message BatchNormInferenceRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - ComputationDataHandle mean = 4; - ComputationDataHandle variance = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message BatchNormGradRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle mean = 3; - ComputationDataHandle variance = 4; - ComputationDataHandle grad_output = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message CrossReplicaSumRequest { - ComputationDataHandle operand = 2; -} - -message SelectAndScatterRequest { - // Operand array on which the windows slide. - ComputationDataHandle operand = 2; - - // Source array for the data to scatter. - ComputationDataHandle source = 3; - - // Initial scalar value for each element in the output. - ComputationDataHandle init_value = 4; - - // Window configuration. - Window window = 5; - - // Binary function used to select an element from each window. - ComputationHandle select = 6; - - // Binary function used to combine each scattered value from source with the - // current output value at the selected location. - ComputationHandle scatter = 7; -} - -message ReverseRequest { - ComputationDataHandle operand = 2; - repeated int64 dimensions = 3; -} - -message BroadcastRequest { - ComputationDataHandle operand = 2; - repeated int64 broadcast_sizes = 3; -} - -message PadRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle padding_value = 3; - PaddingConfig padding_config = 4; -} - -message ReshapeRequest { - ComputationDataHandle operand = 2; - - // The dimension order for collapse (from fastest-changing to slowest). - repeated int64 dimensions = 3; - - // The new dimension sizes (from dimension 0 to n-1). - repeated int64 new_sizes = 4; -} - -message TransposeRequest { - ComputationDataHandle operand = 2; - - // The permutation of the operand's dimensions (in the range 0 to n-1). - repeated int64 dimensions = 3; -} - -message ParameterRequest { - Shape shape = 2; - int64 parameter = 3; - string name = 4; -} - -message GetLocalShapeRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message GetLocalShapeResponse { - Shape shape = 1; -} - -message TraceRequest { - string tag = 2; - ComputationDataHandle operand = 3; -} - -message ConvertRequest { - ComputationDataHandle operand = 2; - PrimitiveType new_element_type = 3; -} - -message ConcatenateRequest { - repeated ComputationDataHandle operands = 2; - // The dimension in which we concatenate; e.g. if you had dimension arrays of - // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1]. - // Attempting to concatenate those in dimension 1 would produce an error, as - // 4 != 5 (and there is no ragged array support). - int64 dimension = 3; -} - -message ConditionalRequest { - ComputationDataHandle predicate = 2; - ComputationDataHandle true_operand = 3; - ComputationHandle true_computation = 4; - ComputationDataHandle false_operand = 5; - ComputationHandle false_computation = 6; -} - -message WhileRequest { - ComputationHandle condition = 2; - ComputationHandle body = 3; - ComputationDataHandle init = 4; -} - -enum UnaryOperation { - UNOP_INVALID = 0; - - // Elementwise, logical negation on booleans and bitwise negation on ints. - UNOP_NOT = 1; - - // Elementwise, computes e^x. - UNOP_EXP = 2; - - // Elementwise, computes -x. - UNOP_NEGATE = 3; - - // Puts the elements in the operand into sorted order. - UNOP_SORT = 4; - - // Elementwise, computes tanh(x). - UNOP_TANH = 5; - - // Elementwise, computes the natural logarithm of x. - UNOP_LOG = 6; - - // Elementwise, computes the floor of x. - UNOP_FLOOR = 7; - - // Elementwise, computes the ceil of x. - UNOP_CEIL = 8; - - // Elementwise, computes the abs of x. - UNOP_ABS = 9; - - // Elementwise, computes the sign of x. - UNOP_SIGN = 10; - - // Elementwise, tests if values are finite (not NaN or inf) - UNOP_IS_FINITE = 11; - - // Elementwise, computes the cosine of x. - UNOP_COS = 12; - - // Elementwise, computes the sine of x. - UNOP_SIN = 13; - - // Elementwise, rounds x to nearest integral value, rounding half-way cases - // away from zero. - UNOP_ROUND_NEAREST_AFZ = 14; - - // Elementwise, extract real component of complex x. - UNOP_REAL = 15; - - // Elementwise, extract real component of complex x. - UNOP_IMAG = 16; - - // Elementwise, computes clz(x). - UNOP_CLZ = 17; - - // Elementwise, computes exp(x)-1. - UNOP_EXPM1 = 18; - - // Elementwise, computes log(x+1). - UNOP_LOG1P = 19; -} - -message UnaryOpRequest { - UnaryOperation unop = 2; - ComputationDataHandle operand = 3; -} - -enum BinaryOperation { - BINOP_INVALID = 0; - - // Arithmetic operations. - BINOP_ADD = 1; - BINOP_DIV = 2; - BINOP_MUL = 3; - BINOP_SUB = 4; - - // Comparison operators. - BINOP_EQ = 5; - BINOP_GE = 6; - BINOP_GT = 7; - BINOP_LE = 8; - BINOP_LT = 9; - BINOP_NE = 10; - - // Element-wise maximum. - BINOP_MAX = 14; - - // Element-wise minimum. - BINOP_MIN = 15; - - // Raises the left-hand-side to the right-hand-side power. - BINOP_POW = 16; - - // Remainder operation. - BINOP_REM = 17; - - // Element-wise, logical operators on booleans and bitwise operators on ints. - BINOP_AND = 18; - BINOP_OR = 19; - - BINOP_SHIFT_LEFT = 20; - BINOP_SHIFT_RIGHT_ARITHMETIC = 21; - BINOP_SHIFT_RIGHT_LOGICAL = 22; - - // Complex from real, imag. - BINOP_COMPLEX = 23; - - // Computes the 4-quadrant arctangent of the y, x input arguments. - BINOP_ATAN2 = 24; -} - -message BinaryOpRequest { - BinaryOperation binop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - repeated int64 broadcast_dimensions = 5; -} - enum RandomDistribution { RNG_INVALID = 0; @@ -897,67 +501,6 @@ enum RandomDistribution { // Next: 4 } -message RngRequest { - RandomDistribution distribution = 2; - repeated ComputationDataHandle parameter = 3; - Shape shape = 4; -} - -enum TernaryOperation { - TRIOP_INVALID = 0; - - // Given a predicate and two operands, selects operand0 if the predicate is - // true and operand1 if the predicate is false. - TRIOP_SELECT = 1; - - // Given a min, max and an operand returns the operand if between min and max, - // else returns min if operand is less than min or max if operand is greater - // than max. - TRIOP_CLAMP = 3; -} - -message TernaryOpRequest { - TernaryOperation triop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - ComputationDataHandle ehs = 5; -} - -enum VariadicOperation { - VAROP_INVALID = 0; - - // Creates a tuple from its operands. - VAROP_TUPLE = 1; -} - -message VariadicOpRequest { - VariadicOperation varop = 2; - repeated ComputationDataHandle operands = 3; -} - -message ReducePrecisionRequest { - ComputationDataHandle operand = 1; - int32 exponent_bits = 2; - int32 mantissa_bits = 3; -} - -message SendRequest { - ComputationDataHandle operand = 1; - ChannelHandle channel_handle = 2; -} - -message RecvRequest { - Shape shape = 1; - ChannelHandle channel_handle = 2; -} - -message GatherRequest { - ComputationDataHandle input = 1; - ComputationDataHandle gather_indices = 2; - GatherDimensionNumbers dimension_numbers = 3; - repeated int64 window_bounds = 4; -} - message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -988,59 +531,3 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } - -message OpRequest { - ComputationHandle computation = 1; - OpMetadata metadata = 33; - OpSharding sharding = 40; - - oneof op { - BinaryOpRequest binary_op_request = 2; - BroadcastRequest broadcast_request = 3; - CallRequest call_request = 4; - ConcatenateRequest concatenate_request = 5; - ConstantRequest constant_request = 6; - ConvertRequest convert_request = 7; - ConvolveRequest convolve_request = 8; - CrossReplicaSumRequest cross_replica_sum_request = 9; - CustomCallRequest custom_call_request = 10; - DotRequest dot_request = 43; - DynamicSliceRequest dynamic_slice_request = 11; - DynamicUpdateSliceRequest dynamic_update_slice_request = 12; - GetTupleElementRequest get_tuple_element_request = 13; - InfeedRequest infeed_request = 14; - MapRequest map_request = 15; - PadRequest pad_request = 16; - ParameterRequest parameter_request = 17; - ReducePrecisionRequest reduce_precision_request = 36; - ReduceRequest reduce_request = 18; - ReduceWindowRequest reduce_window_request = 19; - ReshapeRequest reshape_request = 20; - ReverseRequest reverse_request = 21; - RngRequest rng_request = 22; - SelectAndScatterRequest select_and_scatter_request = 23; - SliceRequest slice_request = 24; - TernaryOpRequest ternary_op_request = 25; - TraceRequest trace_request = 26; - TransposeRequest transpose_request = 34; - UnaryOpRequest unary_op_request = 27; - VariadicOpRequest variadic_op_request = 28; - WhileRequest while_request = 29; - SendRequest send_request = 30; - RecvRequest recv_request = 31; - OutfeedRequest outfeed_request = 32; - BatchNormTrainingRequest batch_norm_training_request = 35; - BatchNormGradRequest batch_norm_grad_request = 37; - BatchNormInferenceRequest batch_norm_inference_request = 38; - FftRequest fft_request = 41; - ConvertRequest bitcast_convert_request = 42; - ConditionalRequest conditional_request = 44; - HostComputeRequest host_compute_request = 45; - GatherRequest gather_request = 46; - // Next: 47 - } -} - -message OpResponse { - ComputationDataHandle output = 1; -} diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 0f9c80404ad33c39ae783e0bfa3cfb26e342fe3d..fffab5a79549bfa2d74bea227b6e0245834a84c2 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -31,13 +31,14 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/contrib/autograph", "//tensorflow/contrib/constrained_optimization", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/data", - "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/deprecated:deprecated_py", + "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/estimator:estimator_py", @@ -83,7 +84,6 @@ py_library( "//tensorflow/contrib/proto", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", - "//tensorflow/contrib/autograph", "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/recurrent:recurrent_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", @@ -114,6 +114,7 @@ py_library( "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", + "//tensorflow/python/estimator:estimator_py", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ "//tensorflow/contrib/tensorrt:init_py", ]) + select({ diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 62d1b1cf079d04d50e4899cfd9ba1d405ee1efb9..881808a98bfd688c2efaa8beb5b8f11a2527fee8 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -11,6 +11,16 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_py_test") +py_library( + name = "all_reduce_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":all_reduce", + "//tensorflow/python:util", + ], +) + py_library( name = "all_reduce", srcs = [ diff --git a/tensorflow/contrib/all_reduce/__init__.py b/tensorflow/contrib/all_reduce/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9824f4cfbf83d9b001a58cafe582226e96c076f --- /dev/null +++ b/tensorflow/contrib/all_reduce/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""All-reduce implementations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.all_reduce.python.all_reduce import * + +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'build_ring_all_reduce', + 'build_recursive_hd_all_reduce', + 'build_shuffle_all_reduce', + 'build_nccl_all_reduce', + 'build_nccl_then_ring', + 'build_nccl_then_recursive_hd', + 'build_nccl_then_shuffle', + 'build_shuffle_then_ring', + 'build_shuffle_then_shuffle' +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index c10179ba8b290b6209f5567d6323df4bcf711585..f0b1c92cf7e4b760381da38febd9682ce2a4f27c 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -1,6 +1,8 @@ # Description: # JNI-based Java inference interface for TensorFlow. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD index 30dd846893c30b9205972bd5216cc1871ab03d76..ad700ac4a0342e2a7bc07a6ecf6710cea892e296 100644 --- a/tensorflow/contrib/autograph/BUILD +++ b/tensorflow/contrib/autograph/BUILD @@ -23,9 +23,9 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/autograph/impl", + "//tensorflow/contrib/autograph/lang", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/utils", - "@gast_archive//:gast", - "@six_archive//:six", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md index a4aec8c74a9ad1418072471a5d3cde8c3b968a38..06fb7b03d5dbbfd2fcb6d6a2ecfe5c817f94a469 100644 --- a/tensorflow/contrib/autograph/CONTRIBUTING.md +++ b/tensorflow/contrib/autograph/CONTRIBUTING.md @@ -1,4 +1,4 @@ -# How to Contribute +# How to contribute We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below. @@ -46,3 +46,50 @@ bazel test --config=opt --copt=-O3 --copt=-march=native \ ``` from the root of the `tensorflow` repository. For more details see the [main TensorFlow Contributing File](../../CONTRIBUTING.md) + +## Developer info + +### Module structure + +The graph below describes the dependencies between AutoGraph modules (not to be mistaken with the directory structure for these modules, which is flat): + +```dot +digraph d_modules { + autograph [style=filled]; + converters; + core; + impl; + lang; + operators; + + autograph -> impl + autograph -> lang + + impl -> converters + impl -> core + impl -> operators + + lang -> operators + + converters -> core + converters -> lang +} +``` + +`autograph` is the sole user-visible module. + +A short description of the modules: + + * `autograph`: the main module imported by the user and by the generated code; only contains declarations + * `impl`: high level code and the implementation of the api frontend + * `core`: base classes for the AutoGraph source code transformation logic; see in particular `converter.py` + * `lang`: special user-visible functions that serve as extensions to the Python language + * `converters`: collection of source code transformation modules specialized for particular AutoGraph features + * `operators`: collection of operators that AutoGraph overloads; these correspond to Python operators as well as Python syntactic structures, like control flow + +There are two additional modules, `pyct` and `utils`. These are independent of AutoGraph: + + * `pyct`: a general purpose Python source code transformation library + * `utils`: the kitchen sync; deprecated + +Note: we have a long term plan to factor out an implementation of `impl` and `converters` that is independent of autograph, into a general purpose Python operator overloading library. diff --git a/tensorflow/contrib/autograph/LIMITATIONS.md b/tensorflow/contrib/autograph/LIMITATIONS.md new file mode 100644 index 0000000000000000000000000000000000000000..d8b1cb7616ac348981bf2b69d6e2fd8d8a6e6b78 --- /dev/null +++ b/tensorflow/contrib/autograph/LIMITATIONS.md @@ -0,0 +1,50 @@ +# Capabilities and Limitations + +TF AutoGraph converts Eager Python code into TensorFlow graph-mode code. For example, users write code with `if` and `while` and AutoGraph automatically converts it into the equivalent `tf.cond`, and `tf.while_loop`. + +Python is a large language, so hoping to convert arbitrary Python code directly to TF graphs is overly ambitious. However, the Python code written to metaprogram TF graphs is in practice a restricted subset. We aim to support as much of this subset as possible. The table below lays out what we currently handle, what we hope to support, and what we have no plans to support. + +# Python Language Support Status + +Note: as more complex features in TensorFlow are made more accessible using AutoGraph, we expect to come across use cases that haven't been tried before, some of which might reveal rare bugs. If we do find any such bugs, we may add additional restrictions for the affected configurations, until those bugs are resolved. + + Construct | Supported now? | Plan to support? | Notes + :--------- | :--------------: | :----------------: | :----- +If statement | Yes | | Converts to `tf.cond`. If variables are created in one branch that don’t exist in another, which is inexpressible in TF, we throw a clear error. +For statement | Yes | | We will specialize `for` loops with unknown and known lengths, as well as for loops over TF datasets. Converts to `tf.while_loop`, with an additional `maximum_iterations` hint, if that is known. Creating variables inside the loop that are used later outside the loop is not supported, as the loop may have no iterations. +While statement | Yes | | Converts to `tf.while_loop`. Creating variables inside the loop is not supported, as the loop may have no iterations. +Continue and break | Yes | | Converts to boolean flags and extra predicates in loop tests. +Composition of control flow | Yes | | Arbitrary composition of `if`, `while`, `for`, `break`, and `continue`, along with other supported language elements, is supported and tested. +Iterators | Some | Yes | Not all iterators supported, but we plan to support everything that can be desugared, such as `enumerate` and `zip`. +Multiple return values | Yes | | We desugar them into variables, boolean flags and conditionals so that the function has a single return value at the end, and provide a clear error if we are unable to do so. +Print expression | Yes | | Wrapped in `PyFunc`, and given proper control dependencies. Optional support for using tf.Log when py_func is undesirable exists. +Static function calls | Yes | | Non-recursive function calls +Nested call trees | Yes | | For example, `f` calls `g` which calls `h`, all of which need conversion. +Recursive function calls | No | Maybe | Based on available support in TF. Currently `function.Defun` is the best candidate, but it is not reentrant. +Python built-ins | Some | Yes | `print`, `len`, `range`, `xrange`, `int`, `float` are supported, and we plan to support or clearly error on all [Python built-ins](https://docs.python.org/3/library/functions.html). +List operations | Yes | | We convert list creation, append, pop and indexing to their TF TensorArray equivalents. However, we do need some extra type hints to fully convert correctly. We hope to remove this limitation. +Function variables | Yes | | e.g. `f_new = f_orig; f_new()` +Lambda functions | No | Yes | Planned feature. +Classes | Yes | | Classes can be converted all at once, or method-by-method. Some limitations exist around static and class methods. +Subclasses | Yes | | Subclassing library objects like tf.keras.Model is also supported. +Dynamic types | Some | | `o = C1() if foo else C2(); o.bar()`. Some scenarios where types are data-dependent may not be supported. We will raise a meaningful error in that case. +Dynamic code / exec | No | | +Reflection | No | | +Try / Except | No | No | No current sane TF equivalent. +Global variables | Restricted | | In general, we only support read-only access to arguments or variables defined outside the converted code. A few exceptions include TensorFlow library code. +Functions with side effects | Some | | Side effects are allowed, under certain circumstances. +Collections | Some | Yes | We currently support lists. There are currently no TF equivalents of dictionaries or tuples. +List Comprehensions | Yes | | We desugar `ListComp` into the appropriate combination of `For` and `If` statements. Other comprehensions are currently very low priority. +Custom context managers | No | Yes | Currently low priority. Left unconverted currently. +Generators | No | Maybe | Could be achievable using queues; very low priority. +Assertions | Yes | | As `tf.Assert` +Deletion | Yes | Maybe | Currently unconverted. If new semanti cs are required for `del`, we are able to add it in. +Inline imports | No | Yes | For example, `import numpy as np; np.eye(3)`. Currently low priority. +Async | No | No | + +## Extra capabilities + + - We liberally add name scopes to generated functions + - Operations get decent default names everywhere (planned) + - Statements that have no output values are given correct control dependencies. For example, `for i in range(n): print(i)` will have control dependencies to ensure the `print` statements are executed serially. + diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 674859bed4ec157d5d5b33b6fc015c930e54b392..829a57d8e61ee4a41076f7397488cd85bdca1376 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -120,3 +120,15 @@ You can use the functional API to inspect the generated code as well: print(ag.to_code(f)) # Output: ``` + +## Filing bugs and feature requests + +### Reporting a bug + + - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message. + - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message. + - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you. + +### Requesting a feature + +If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there. diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md index 866e5f583a34570dfddc733f57561ed1d2b7c5bf..7e6b0cc27dd1cf8c0f459a0a34f98092728342a2 100644 --- a/tensorflow/contrib/autograph/STYLE_GUIDE.md +++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md @@ -20,7 +20,17 @@ Naming conventions: Below are AutoGraph-specific conventions. In the event of conflict, it supercedes all previous conventions. -1. __Citations in Docstrings.__ Write a `#### References` subsection at the +1. __Types in docstrings.__ Use [PEP 484][https://www.python.org/dev/peps/pep-0484/] + notation to describe the type for args, return values and attributes. + + Example: + + ``` + Args: + foo: Dict[str, List[int]], a dictionary of sorts + ``` + +2. __Citations in Docstrings.__ Write a `#### References` subsection at the bottom of any docstring with citations. Use ICLR’s bibliography style to write references; for example, order entries by the first author's last name. Add a link to the paper if the publication is open source (ideally, @@ -60,12 +70,12 @@ it supercedes all previous conventions. https://arxiv.org/abs/1803.04386 ``` -2. Avoid LaTeX in docstrings. +3. Avoid LaTeX in docstrings. * It is not rendered in many (if not most) editors and can be hard to read for both LaTeX experts and non-experts. -3. Write docstring and comment math using ASCII friendly notation; python using +4. Write docstring and comment math using ASCII friendly notation; python using operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`, `sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx: x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`. diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 79d73af98097aea418f2116aee40b2572b418ef7..361cf2d77c7e46912d5bff5881df2ffa897c5179 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -30,7 +30,9 @@ from tensorflow.contrib.autograph.impl.api import do_not_convert from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.impl.api import to_graph -from tensorflow.contrib.autograph.impl.special_functions import stack +from tensorflow.contrib.autograph.lang.directives import set_element_type +from tensorflow.contrib.autograph.lang.directives import set_loop_options +from tensorflow.contrib.autograph.lang.special_functions import stack from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented @@ -42,8 +44,11 @@ _allowed_symbols = [ 'do_not_convert', 'to_code', 'to_graph', - # Special functions and overloaded operators + # Overloaded operators 'operators', + # Python language "extensions" + 'set_element_type', + 'set_loop_options', 'stack', # Exceptions 'AutographParseError', diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index 8f9bffa55e44e4942bb3845945b3d440c7957cc9..b2e2e27673dafe290cef40a9fe0a834bfe1ea61f 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -31,29 +31,17 @@ py_library( "name_scopes.py", "side_effect_guards.py", "single_return.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ - "@gast_archive//:gast", - ], -) - -py_library( - name = "test_lib", - srcs = [ - "converter_test_base.py", - ], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:__subpackages__"], - deps = [ - ":converters", - "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/core", + "//tensorflow/contrib/autograph/lang", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/pyct/static_analysis", - "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:util", "@gast_archive//:gast", - "@six_archive//:six", ], ) @@ -63,7 +51,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -73,7 +62,8 @@ py_test( srcs = ["break_statements_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -84,7 +74,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -96,7 +87,8 @@ py_test( srcs_version = "PY2AND3", tags = ["no_windows"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/impl", "//tensorflow/python:client_testlib", ], @@ -107,7 +99,8 @@ py_test( srcs = ["continue_statements_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -117,7 +110,8 @@ py_test( srcs = ["control_flow_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -126,8 +120,13 @@ py_test( name = "decorators_test", srcs = ["decorators_test.py"], srcs_version = "PY2AND3", + tags = [ + "no_pip", + "no_windows", + ], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -136,7 +135,8 @@ py_test( name = "name_scopes_test", srcs = ["name_scopes_test.py"], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], @@ -147,7 +147,8 @@ py_test( srcs = ["list_comprehension_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -157,7 +158,8 @@ py_test( srcs = ["lists_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -167,7 +169,8 @@ py_test( srcs = ["logical_expressions_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -182,7 +185,8 @@ py_test( "notap", ], deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/python:client_testlib", ], ) @@ -192,7 +196,8 @@ py_test( srcs = ["single_return_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], @@ -203,7 +208,20 @@ py_test( srcs = ["ifexp_test.py"], srcs_version = "PY2AND3", deps = [ - ":test_lib", + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":converters", + "//tensorflow/contrib/autograph/core:test_lib", "//tensorflow/contrib/autograph/pyct", "//tensorflow/python:client_testlib", ], diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py index 3b0db677ce5e417e7afea8d8fe4121a0352bb6d7..e664a403a5fb800e7d0dddfa5695330927aaf4e0 100644 --- a/tensorflow/contrib/autograph/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class AssertsTransformer(transformer.Base): +class AssertsTransformer(converter.Base): """Transforms Print nodes to Call so they can be handled as functions.""" def visit_Assert(self, node): @@ -45,5 +45,5 @@ class AssertsTransformer(transformer.Base): raise NotImplementedError('can only convert string messages for now.') -def transform(node, context): - return AssertsTransformer(context).visit(node) +def transform(node, ctx): + return AssertsTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py index cc913febe8d0f411588af69b87ec52ce58f4469c..2cd0e626bc4552bd40bc94b890fdcc7efcafb3f3 100644 --- a/tensorflow/contrib/autograph/converters/asserts_test.py +++ b/tensorflow/contrib/autograph/converters/asserts_test.py @@ -21,11 +21,11 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.converters import asserts -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class AssertsTest(converter_test_base.TestCase): +class AssertsTest(converter_testing.TestCase): def test_transform(self): diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 775d92c1d9f8bc35d1eda62f3f3ef7ee43414779..a990e359a2a25a57ee2a4f8a866350633f3b9ea8 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -29,7 +29,7 @@ BREAK_USED = 'break_used' CONTROL_VAR_NAME = 'control_var_name' -class BreakStatementTransformer(transformer.Base): +class BreakStatementTransformer(converter.Base): """Canonicalizes break statements into additional conditionals.""" def visit_Break(self, node): @@ -67,7 +67,7 @@ class BreakStatementTransformer(transformer.Base): def visit_While(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break_', scope.referenced) + break_var = self.ctx.namer.new_symbol('break_', scope.referenced) node.test = self.visit(node.test) node.body, break_used = self._track_body(node.body, break_var) @@ -97,7 +97,7 @@ class BreakStatementTransformer(transformer.Base): def visit_For(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break_', scope.referenced) + break_var = self.ctx.namer.new_symbol('break_', scope.referenced) node.target = self.visit(node.target) node.iter = self.visit(node.iter) @@ -137,5 +137,5 @@ class BreakStatementTransformer(transformer.Base): return node -def transform(node, context): - return BreakStatementTransformer(context).visit(node) +def transform(node, ctx): + return BreakStatementTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py index 1af59e9b5260fe0d3a3ef72c7a003dc451e230f3..dcff1c54c2f9300d58d217517e108d634ae85fb4 100644 --- a/tensorflow/contrib/autograph/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import break_statements -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class BreakCanonicalizationTest(converter_test_base.TestCase): +class BreakCanonicalizationTest(converter_testing.TestCase): def test_basic_while(self): diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index 231e4ee35a72f51845a476d9f605986ac73b4676..b26c52294c2d1c11ce14d8a2903f7f88079a703f 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class BuiltinFunctionTransformer(transformer.Base): +class BuiltinFunctionTransformer(converter.Base): """Handles builtin functions. This transformer only covers functions that are translated into a @@ -68,5 +68,5 @@ class BuiltinFunctionTransformer(transformer.Base): return self.visit(function_call) -def transform(node, context): - return BuiltinFunctionTransformer(context).visit(node) +def transform(node, ctx): + return BuiltinFunctionTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index 30272409df322560b04ba75b3e1cb6f9ad5ff0af..e9000e518ce14f9e0ea486d5b3e374439b8c78ca 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -23,13 +23,13 @@ import sys import six from tensorflow.contrib.autograph.converters import builtin_functions -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class BuiltinFunctionsTest(converter_test_base.TestCase): +class BuiltinFunctionsTest(converter_testing.TestCase): def test_len(self): diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index b6ecdcb7809b1ad7e7461324cb6a110ef4180609..a36b3d77a9233daed864c616306b2ad27f582a38 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -26,12 +26,12 @@ from collections import namedtuple import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -45,6 +45,9 @@ KNOWN_NUMPY_FUNCTIONS = { } +# TODO(mdan): Get rid of these interfaces. Can now depend directly on Namer. + + class FunctionNamer(object): """Describes the interface for CallTreeTransformer's namer.""" @@ -76,20 +79,18 @@ class FunctionNamer(object): raise NotImplementedError() -class CallTreeTransformer(transformer.Base): - """Transforms the call tree by renaming transformed symbols.""" +# TODO(mdan): Rename to CallsTransformer. - def __init__(self, context, uncompiled_modules, nocompile_decorators): - super(CallTreeTransformer, self).__init__(context) - self.uncompiled_modules = uncompiled_modules - self.nocompile_decorators = nocompile_decorators + +class CallTreeTransformer(converter.Base): + """Transforms the call tree by renaming transformed symbols.""" def _resolve_name(self, node): """Used to resolve decorator info.""" if isinstance(node, gast.Call): return self._resolve_name(node.func) if isinstance(node, gast.Name): - return self.context.namespace.get(node.id) + return self.ctx.namespace.get(node.id) if isinstance(node, gast.Attribute): parent = self._resolve_name(node.value) if parent is not None: @@ -119,12 +120,12 @@ class CallTreeTransformer(transformer.Base): """Determines whether an entity should be compiled in the context.""" # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether. module_name = fqn[0] - for mod in self.uncompiled_modules: + for mod in self.ctx.program.uncompiled_modules: if module_name.startswith(mod[0] + '.'): return False for i in range(1, len(fqn)): - if fqn[:i] in self.uncompiled_modules: + if fqn[:i] in self.ctx.program.uncompiled_modules: return False # Check for local decorations @@ -140,7 +141,7 @@ class CallTreeTransformer(transformer.Base): if hasattr(target_entity, '__pyct_is_compile_decorator'): return False - if target_entity in self.nocompile_decorators: + if target_entity in self.ctx.program.autograph_decorators: return False # Inspect the target function decorators. If any include a @convert @@ -159,7 +160,7 @@ class CallTreeTransformer(transformer.Base): for dec in target_node.decorator_list: decorator_fn = self._resolve_name(dec) if (decorator_fn is not None and - decorator_fn in self.nocompile_decorators): + decorator_fn in self.ctx.program.autograph_decorators): return False return True @@ -174,7 +175,7 @@ class CallTreeTransformer(transformer.Base): return node if anno.hasanno(node, 'is_constructor'): - new_name = self.context.namer.compiled_class_name( + new_name = self.ctx.namer.compiled_class_name( target_fqn, live_entity=target_entity) do_rename = True else: @@ -183,7 +184,7 @@ class CallTreeTransformer(transformer.Base): else: # Fallback - not reliable. owner_type = inspect_utils.getmethodclass(target_entity) - new_name, do_rename = self.context.namer.compiled_function_name( + new_name, do_rename = self.ctx.namer.compiled_function_name( target_fqn, live_entity=target_entity, owner_type=owner_type) if do_rename: @@ -264,15 +265,16 @@ class CallTreeTransformer(transformer.Base): return node def visit_Call(self, node): - # If the function is wrapped by one of the marker decorators, + # If the function call is wrapped by one of the marker decorators, # consider it graph ready. if anno.hasanno(node.func, 'live_val'): target_entity = anno.getanno(node.func, 'live_val') - if target_entity in self.nocompile_decorators: + if target_entity in self.ctx.program.autograph_decorators: if len(node.args) < 1: raise ValueError( 'Found call to decorator function "%s", but it had no arguments. ' - 'A decorator needs at least an argument.') + 'A decorator needs at least one positional argument.' % + target_entity) anno.setanno(node.args[0], 'graph_ready', True) self.generic_visit(node) @@ -309,27 +311,20 @@ class CallTreeTransformer(transformer.Base): # ensure that they return the correct value. return node - if self.context.recursive: + if self.ctx.program.recursive: node = self._insert_dynamic_conversion(node) return node -def transform(node, context, uncompiled_modules, nocompile_decorators): +def transform(node, ctx): """Transform function call to the compiled counterparts. Args: - node: AST to transform. - context: An EntityContext object. - uncompiled_modules: set of string tuples, each tuple represents the fully - qualified name of a package containing functions that will not be - compiled. - nocompile_decorators: A tuple containing decorators to be stripped from - functions during conversion. + node: AST + ctx: EntityContext Returns: A tuple (node, new_names): node: The transformed AST new_names: set(string), containing any newly-generated names """ - t = CallTreeTransformer(context, uncompiled_modules, nocompile_decorators) - node = t.visit(node) - return node + return CallTreeTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py index 303dd54a4ee49de27fad0c5cdc2d6274abfe0fa8..27d8281b856f505062ceacc8ad50c8cbc2ce6c81 100644 --- a/tensorflow/contrib/autograph/converters/call_trees_test.py +++ b/tensorflow/contrib/autograph/converters/call_trees_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.autograph.converters import call_trees -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -29,7 +29,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class CallTreesTest(converter_test_base.TestCase): +class CallTreesTest(converter_testing.TestCase): def test_basic(self): @@ -43,7 +43,7 @@ class CallTreesTest(converter_test_base.TestCase): return test_fn_1(a) + 1 node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 @@ -60,7 +60,7 @@ class CallTreesTest(converter_test_base.TestCase): return f() + 3 node = self.parse_and_analyze(test_fn_2, {}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: # 10 = 7 (from the mock) + 3 (from test_fn_2) @@ -78,9 +78,9 @@ class CallTreesTest(converter_test_base.TestCase): node = self.parse_and_analyze( TestClass.test_fn_2, {'TestClass': TestClass}, - namer=converter_test_base.FakeNoRenameNamer(), + namer=converter_testing.FakeNoRenameNamer(), arg_types={'self': (TestClass.__name__, TestClass)}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: tc = TestClass() @@ -92,7 +92,7 @@ class CallTreesTest(converter_test_base.TestCase): setattr(a, 'foo', 'bar') node = self.parse_and_analyze(test_fn, {'setattr': setattr}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: with self.test_session() as sess: @@ -115,7 +115,7 @@ class CallTreesTest(converter_test_base.TestCase): return np.random.binomial(2, 0.5) node = self.parse_and_analyze(test_fn, {'np': np}) - node = call_trees.transform(node, self.ctx, (), ()) + node = call_trees.transform(node, self.ctx) with self.compiled(node, dtypes.int64) as result: result.np = np @@ -130,13 +130,13 @@ class CallTreesTest(converter_test_base.TestCase): a = math_ops.add(a, constant_op.constant(1)) return a - node = self.parse_and_analyze(test_fn, { - 'math_ops': math_ops, - 'constant_op': constant_op - }) - node = call_trees.transform(node, self.ctx, - set(((math_ops.__name__,), - (constant_op.__name__,))), ()) + node = self.parse_and_analyze( + test_fn, { + 'math_ops': math_ops, + 'constant_op': constant_op + }, + arg_types=set(((math_ops.__name__,), (constant_op.__name__,)))) + node = call_trees.transform(node, self.ctx) with self.compiled(node) as result: result.math_ops = math_ops diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 0417817a77e706fc0ce805f7391bea600f5fbb2d..958bde0a58764e705c35ab73ce879b2c11ce7cdc 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -31,7 +31,7 @@ GUARD_CREATED = 'guard_created' CREATE_GUARD_NEXT = 'create_guard_next' -class ContinueCanonicalizationTransformer(transformer.Base): +class ContinueCanonicalizationTransformer(converter.Base): """Canonicalizes continue statements into additional conditionals.""" def visit_Continue(self, node): @@ -85,7 +85,7 @@ class ContinueCanonicalizationTransformer(transformer.Base): def _visit_loop_body(self, node, nodes): self.enter_local_scope() scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - continue_var = self.context.namer.new_symbol('continue_', scope.referenced) + continue_var = self.ctx.namer.new_symbol('continue_', scope.referenced) self.set_local(CONTROL_VAR_NAME, continue_var) nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) @@ -135,5 +135,5 @@ class ContinueCanonicalizationTransformer(transformer.Base): return node -def transform(node, namer): - return ContinueCanonicalizationTransformer(namer).visit(node) +def transform(node, ctx): + return ContinueCanonicalizationTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py index bcbb316d7459aa5a25bb0bd128cd6e359a393288..2ce1837972c50bbc4921487a290f5cb2f782b5f3 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements_test.py +++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import continue_statements -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class ContinueCanonicalizationTest(converter_test_base.TestCase): +class ContinueCanonicalizationTest(converter_testing.TestCase): def test_basic_continue(self): diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index d7ddbe8a04f64848d6ec21155d8d85f60e19d276..f4a87106279d5658ecaa90a577cbe741711ba22e 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import cfg from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -45,9 +45,8 @@ class SymbolNamer(object): raise NotImplementedError() -class ControlFlowTransformer(transformer.Base): +class ControlFlowTransformer(converter.Base): """Transforms control flow structures like loops an conditionals.""" - def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names, body, returns): if aliased_orig_names: @@ -141,10 +140,10 @@ class ControlFlowTransformer(transformer.Base): aliased_orelse_orig_names = tuple(orelse_scope.modified - orelse_scope.created) aliased_body_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), body_scope.referenced) + self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced) for s in aliased_body_orig_names) aliased_orelse_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced) + self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced) for s in aliased_orelse_orig_names) alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) @@ -165,9 +164,8 @@ class ControlFlowTransformer(transformer.Base): else: results = gast.Tuple([s.ast() for s in modified], None) - body_name = self.context.namer.new_symbol('if_true', body_scope.referenced) - orelse_name = self.context.namer.new_symbol('if_false', - orelse_scope.referenced) + body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced) + orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced) if modified: def build_returns(aliased_names, alias_map, scope): @@ -235,7 +233,7 @@ class ControlFlowTransformer(transformer.Base): raise ValueError('cannot convert while loop: no outputs') state_ssf = [ - self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state ] ssf_map = { name: ssf @@ -267,11 +265,9 @@ class ControlFlowTransformer(transformer.Base): state=state, state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, - test_name=self.context.namer.new_symbol('loop_test', - body_scope.referenced), + test_name=self.ctx.namer.new_symbol('loop_test', body_scope.referenced), test=test, - body_name=self.context.namer.new_symbol('loop_body', - body_scope.referenced), + body_name=self.ctx.namer.new_symbol('loop_body', body_scope.referenced), body=node_body, extra_deps=tuple(s.ast() for s in cond_closure), ) @@ -288,7 +284,7 @@ class ControlFlowTransformer(transformer.Base): state = list(body_closure) state_ssf = [ - self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state + self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state ] ssf_map = { name: ssf @@ -326,17 +322,16 @@ class ControlFlowTransformer(transformer.Base): state_ast_tuple=state_ast_tuple, iter_=node.iter, iterate=node.target, - extra_test_name=self.context.namer.new_symbol('extra_test', - all_referenced), + extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced), extra_test_expr=extra_test, - body_name=self.context.namer.new_symbol('loop_body', all_referenced), + body_name=self.ctx.namer.new_symbol('loop_body', all_referenced), body=node_body) return node -def transform(node, context): - cfg.run_analyses(node, cfg.Liveness(context)) - cfg.run_analyses(node, cfg.Defined(context)) - node = ControlFlowTransformer(context).visit(node) +def transform(node, ctx): + cfg.run_analyses(node, cfg.Liveness(ctx.info)) + cfg.run_analyses(node, cfg.Defined(ctx.info)) + node = ControlFlowTransformer(ctx).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 9d23d9b5b7e8e8480e04fccc1c8c81799abf382b..735eb92a0dd06ee7fd621b92b1a8f894e09cee4a 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.converters import control_flow -from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -27,7 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test -class ControlFlowTest(converter_test_base.TestCase): +class ControlFlowTest(converter_testing.TestCase): def test_simple_while(self): diff --git a/tensorflow/contrib/autograph/converters/decorators.py b/tensorflow/contrib/autograph/converters/decorators.py index 92445f31746cf94856ea43893f99a2ba60355fb5..3471bd11d6073f57a2703b438df95a60f19e8e0c 100644 --- a/tensorflow/contrib/autograph/converters/decorators.py +++ b/tensorflow/contrib/autograph/converters/decorators.py @@ -24,19 +24,14 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import pretty_printer +from tensorflow.python.util import tf_inspect -class DecoratorsTransformer(gast.NodeTransformer): +class DecoratorsTransformer(converter.Base): """Converts or removes decorators.""" - def __init__(self, remove_decorators): - self.remove_decorators = remove_decorators - self.additional_dependencies = set() - - # pylint:disable=invalid-name - def visit_FunctionDef(self, node): self.generic_visit(node) kept_decorators = [] @@ -58,31 +53,53 @@ class DecoratorsTransformer(gast.NodeTransformer): # This is currently verified by tests. continue - if not anno.hasanno(dec_func, 'live_val'): - raise ValueError( - 'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func)) - + original_dec = anno.getanno(dec_func, anno.Basic.QN) dec_value = anno.getanno(dec_func, 'live_val') - if dec_value not in self.remove_decorators: - kept_decorators.append((dec, dec_value)) - for _, dec_value in kept_decorators: - if dec_value.__module__ == '__main__': + if dec_value in self.ctx.program.autograph_decorators: + # AutoGraph decorators do not need to be preserved. + continue + + # When using foo.bar.baz, we only really need to grab foo and import + # that. + dec_support_node = dec_func + while isinstance(dec_support_node, gast.Attribute): + dec_support_node = dec_support_node.value + + if not anno.hasanno(dec_support_node, 'live_val'): raise ValueError( - 'decorator "%s" was not allowed because it is declared ' - 'in the module "%s". To fix this, declare it in a separate ' - 'module that we can import it from.' % (dec_value, - dec_value.__module__)) + 'could not resolve symbol "%s" when looking up decorator "%s"' % + (anno.getanno(dec_support_node, anno.Basic.QN), original_dec)) + + dec_support = anno.getanno(dec_support_node, 'live_val') + # The tuple contains: + # * the AST that represents the decorator + # * the entity supporting the decorator (i.e., what we need to import) + # * the name of the module that needs to be imported for this decorator + # to properly resolve. + # Examples: + # for foo.bar, the tuple is (, , 'foo') + # for baz, the tuple is (, , 'baz') + kept_decorators.append((dec, dec_support, + anno.getanno(dec_support_node, anno.Basic.QN))) + + for _, dec_support, name in kept_decorators: + if tf_inspect.ismodule(dec_support): + self.ctx.program.additional_imports.add( + 'import %s as %s' % (dec_support.__name__, name)) else: - self.additional_dependencies.add(dec_value) - - node.decorator_list = [dec for dec, _ in kept_decorators] + if dec_support.__module__ == '__main__': + raise ValueError( + 'decorator "%s" was not allowed because it is declared ' + 'in the module "%s". To fix this, declare it in a separate ' + 'module that we can import it from.' % (dec_support, + dec_support.__module__)) + self.ctx.program.additional_imports.add( + 'from %s import %s' % (dec_support.__module__, name)) + + node.decorator_list = [dec for dec, _, _ in kept_decorators] return node - # pylint:enable=invalid-name - -def transform(node, remove_decorators): - transformer = DecoratorsTransformer(remove_decorators) - node = transformer.visit(node) - return node, transformer.additional_dependencies +def transform(node, ctx): + return DecoratorsTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/contrib/autograph/converters/decorators_test.py index 9c01f689127dbedad7669c65b03e7da071b2d64d..d41c7fde2474803a438100e7e00ce8e9f675de45 100644 --- a/tensorflow/contrib/autograph/converters/decorators_test.py +++ b/tensorflow/contrib/autograph/converters/decorators_test.py @@ -20,9 +20,10 @@ from __future__ import print_function from functools import wraps -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import decorators +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.platform import test @@ -39,28 +40,35 @@ def simple_decorator(f): return lambda a: f(a) + 1 -def self_removing_decorator(removing_wrapper): +def self_transform_decorator(transform): + def decorator(f): @wraps(f) def wrapper(*args): # This removing wrapper is defined in the test below. This setup is so - # intricate just to simulate how we use the transformer in practice. - transformed_f = removing_wrapper(f, (self_removing_decorator,)) + # intricate in order to simulate how we use the transformer in practice. + transformed_f = transform(f, (self_transform_decorator,)) return transformed_f(*args) + 1 return wrapper return decorator -class DecoratorsTest(converter_test_base.TestCase): +class DecoratorsTest(converter_testing.TestCase): - def _remover_wrapper(self, f, remove_decorators): + def _transform(self, f, autograph_decorators): namespace = { - 'self_removing_decorator': self_removing_decorator, - 'simple_decorator': simple_decorator + 'self_transform_decorator': self_transform_decorator, + 'simple_decorator': simple_decorator, + 'converter_testing': converter_testing, } - node = self.parse_and_analyze(f, namespace) - node, _ = decorators.transform(node, remove_decorators=remove_decorators) - result, _ = compiler.ast_to_object(node) + node = self.parse_and_analyze( + f, + namespace, + recursive=False, + autograph_decorators=autograph_decorators) + node = decorators.transform(node, self.ctx) + import_line = '\n'.join(self.ctx.program.additional_imports) + result, _ = compiler.ast_to_object(node, source_prefix=import_line) return getattr(result, f.__name__) def test_noop(self): @@ -69,15 +77,14 @@ class DecoratorsTest(converter_test_base.TestCase): return a node = self.parse_and_analyze(test_fn, {}) - node, deps = decorators.transform(node, remove_decorators=()) + node = decorators.transform(node, self.ctx) result, _ = compiler.ast_to_object(node) - self.assertFalse(deps) self.assertEqual(1, result.test_fn(1)) def test_function(self): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(a): return a @@ -88,7 +95,7 @@ class DecoratorsTest(converter_test_base.TestCase): class TestClass(object): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(self, a): return a @@ -101,38 +108,39 @@ class DecoratorsTest(converter_test_base.TestCase): # Note that reversing the order of this two doesn't work. @classmethod - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(cls, a): return a # 2 = 1 (a) + 1 (decorator applied exactly once) self.assertEqual(2, TestClass.test_fn(1)) - def test_nested_decorators(self): + def test_nested_decorators_local(self): - @self_removing_decorator(self._remover_wrapper) + @self_transform_decorator(self._transform) def test_fn(a): @simple_decorator def inner_fn(b): return b + 11 return inner_fn(a) - with self.assertRaises(ValueError): + # Expected to fail because simple_decorator cannot be imported. + with self.assertRaises(transformer.AutographParseError): test_fn(1) - # TODO(mdan): Uncomment this test once converter_test_base is updated. - # (can't do it now because it has unrelated pending changes) - # def test_nested_decorators(self): - # - # @self_removing_decorator(self._remover_wrapper) - # def test_fn(a): - # @imported_decorator - # def inner_fn(b): - # return b + 11 - # return inner_fn(a) - # - # # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) - # self.assertEqual(14, test_fn(1)) + def test_nested_decorators_imported(self): + + @self_transform_decorator(self._transform) + def test_fn(a): + + @converter_testing.imported_decorator + def inner_fn(b): + return b + 11 + + return inner_fn(a) + + # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) + self.assertEqual(14, test_fn(1)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/ifexp.py b/tensorflow/contrib/autograph/converters/ifexp.py index 616d222762e09feeba1809f119d915dfbe522283..e996138498ab2b7efa76671d8cc67fd4c6a9d9b8 100644 --- a/tensorflow/contrib/autograph/converters/ifexp.py +++ b/tensorflow/contrib/autograph/converters/ifexp.py @@ -18,11 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class IfExp(transformer.Base): +class IfExp(converter.Base): """Canonicalizes all IfExp nodes into plain conditionals.""" def visit_IfExp(self, node): @@ -34,16 +34,16 @@ class IfExp(transformer.Base): return desugared_ifexp -def transform(node, context): +def transform(node, ctx): """Desugar IfExp nodes into plain conditionals. Args: - node: an AST node to transform - context: a context object + node: ast.AST, the node to transform + ctx: converter.EntityContext Returns: new_node: an AST with no IfExp nodes, only conditionals. """ - node = IfExp(context).visit(node) + node = IfExp(ctx).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/ifexp_test.py b/tensorflow/contrib/autograph/converters/ifexp_test.py index ac6849dcb4bd7dacd84bb205f5c65395d8c2f51e..cdd5a2f591edc1138df1c165577ed375131ddf09 100644 --- a/tensorflow/contrib/autograph/converters/ifexp_test.py +++ b/tensorflow/contrib/autograph/converters/ifexp_test.py @@ -19,12 +19,12 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import ifexp +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class IfExpTest(converter_test_base.TestCase): +class IfExpTest(converter_testing.TestCase): def compiled_fn(self, test_fn, *args): node = self.parse_and_analyze(test_fn, {}) diff --git a/tensorflow/contrib/autograph/converters/list_comprehension.py b/tensorflow/contrib/autograph/converters/list_comprehension.py index d7f292015164e047d054c5d1fb0b391e960bb73d..c4a13ee822ab84706df83256d9e9684c3f7dacba 100644 --- a/tensorflow/contrib/autograph/converters/list_comprehension.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension.py @@ -31,17 +31,14 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class ListCompCanonicalizationTransformer(transformer.Base): +class ListCompCanonicalizationTransformer(converter.Base): """NodeTransformer to canonicalize list comprehensions.""" - def __init__(self, context): - super(ListCompCanonicalizationTransformer, self).__init__(context) - def make_update_list_node(self, list_, elt): return templates.replace('list_.append(elt)', list_=list_, elt=elt)[0] @@ -76,5 +73,5 @@ class ListCompCanonicalizationTransformer(transformer.Base): return make_list + loop_body -def transform(node, context): - return ListCompCanonicalizationTransformer(context).visit(node) +def transform(node, ctx): + return ListCompCanonicalizationTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/list_comprehension_test.py b/tensorflow/contrib/autograph/converters/list_comprehension_test.py index 4758671f5ec83c26cfa54be0ef68f5f564094f6c..2bbee93412ce3174a14f3d60af9435dcf3b82cc6 100644 --- a/tensorflow/contrib/autograph/converters/list_comprehension_test.py +++ b/tensorflow/contrib/autograph/converters/list_comprehension_test.py @@ -18,12 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import list_comprehension +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.platform import test -class ListCompTest(converter_test_base.TestCase): +class ListCompTest(converter_testing.TestCase): def test_basic(self): diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py index b49521b2c328f418828a5e92890aa1b169384b70..d77a04479826779b8aa859d70f2f7ff51138f841 100644 --- a/tensorflow/contrib/autograph/converters/lists.py +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -32,85 +32,196 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.python.framework import dtypes +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -class ListTransformer(transformer.Base): +# Tags for local state. +POP_USES = 'pop_uses' + + +class ListTransformer(converter.Base): """Converts lists and related operations to their TF counterpart.""" - def _empty_list(self, node): - if not anno.hasanno(node, 'element_type'): - raise NotImplementedError( - 'type inference for empty lists is not yet supported; ' - 'use set_element_type(, ) to continue') - dtype = anno.getanno(node, 'element_type') - if not isinstance(dtype, dtypes.DType): - # TODO(mdan): Allow non-TF dtypes? - # That would be consistent with the dynamic dispatch pattern, but - # we must make sure that doesn't become confusing. - raise NotImplementedError('element type "%s" not yet supported' % dtype) - - dtype_name = dtype.name - # TODO(mdan): Does it ever make sense not to use tensor lists? + def visit_List(self, node): + node = self.generic_visit(node) template = """ - tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True) + ag__.new_list(elements) """ - return templates.replace_as_expression(template, dtype_name=dtype_name) + return templates.replace_as_expression(template, elements=node) - def _pre_populated_list(self, node): - raise NotImplementedError('pre-populated lists') + def _replace_append_call(self, node): + assert len(node.args) == 1 + assert isinstance(node.func, gast.Attribute) + template = """ + target = ag__.list_append(target, element) + """ + return templates.replace( + template, + target=node.func.value, + element=node.args[0]) + + def _replace_pop_call(self, node): + # Expressions that use pop() are converted to a statement + expression. + # + # For example: + # + # print(target.pop()) + # + # ... is converted to: + # + # target, target_pop = ag__.list_pop(target) + # print(target_pop) + # + # Here, we just generate the variable name and swap it in, + # and _generate_pop_operation will handle the rest. + # + # Multiple uses of pop() are allowed: + # + # print(tartget.pop(), target.pop()) + # print(tartget.pop().pop()) + # + assert isinstance(node.func, gast.Attribute) + scope = anno.getanno(node, NodeAnno.ARGS_SCOPE) + target_node = node.func.value + + # Attempt to use a related name if can get one. Otherwise use something + # generic. + if anno.hasanno(target_node, anno.Basic.QN): + target_name = anno.getanno(target_node, anno.Basic.QN).ssf() + else: + target_name = 'list' + pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced) + + pop_uses = self.get_local(POP_USES, []) + pop_uses.append((node, pop_var_name)) + self.set_local(POP_USES, pop_uses) + + return templates.replace_as_expression('var_name', var_name=pop_var_name) + + def _replace_stack_call(self, node): + assert len(node.args) == 1 + dtype = anno.getanno( + node.args[0], + 'element_type', + default=templates.replace_as_expression('None')) + template = """ + ag__.list_stack( + target, + opts=ag__.ListStackOpts( + element_dtype=dtype, + original_call=orig_call)) + """ + return templates.replace_as_expression( + template, + dtype=dtype, + target=node.args[0], + orig_call=node.func) - def visit_Expr(self, node): + def visit_Call(self, node): node = self.generic_visit(node) - if isinstance(node.value, gast.Call): - call_node = node.value - - if not anno.hasanno(call_node.func, anno.Basic.QN): - return node - qn = anno.getanno(call_node.func, anno.Basic.QN) - - if qn.qn[-1] == 'append' and (len(call_node.args) == 1): - template = """ - target = ag__.utils.dynamic_list_append(target, element) - """ - node = templates.replace( - template, - target=qn.parent.ast(), - element=call_node.args[0]) + + # TODO(mdan): This is insufficient if target is a function argument. + # In the case of function arguments, we need to add the list to the + # function's return value, because it is being modified. + # TODO(mdan): Checking just the name is brittle, can it be improved? + if isinstance(node.func, gast.Attribute): + func_name = node.func.attr + if func_name == 'append' and (len(node.args) == 1): + node = self._replace_append_call(node) + elif func_name == 'pop' and (len(node.args) <= 1): + node = self._replace_pop_call(node) + elif func_name == 'stack' and (len(node.args) == 1): + node = self._replace_stack_call(node) + return node - def _replace_list_constructors(self, targets, values): - for target in targets: - if (isinstance(target, (gast.Tuple, gast.List)) and - isinstance(values, (gast.Tuple, gast.List))): - n_targets = len(target.elts) - for i in range(n_targets): - target_el, value_el = target.elts[i], values.elts[i] - values.elts[i] = self._replace_list_constructors( - (target_el,), value_el) - return values - if isinstance(values, gast.List): - if values.elts: - return self._pre_populated_list(values) - else: - return self._empty_list(values) - return values - - def visit_Assign(self, node): - node = self.generic_visit(node) + def _generate_pop_operation(self, original_call_node, pop_var_name): + assert isinstance(original_call_node.func, gast.Attribute) + + if original_call_node.args: + pop_element = original_call_node.args[0] + else: + pop_element = parser.parse_expression('None') + # The call will be something like "target.pop()", and the dtype is hooked to + # target, hence the func.value. + dtype = anno.getanno( + original_call_node.func.value, + 'element_type', + default=templates.replace_as_expression('None')) + shape = anno.getanno( + original_call_node.func.value, + 'element_shape', + default=templates.replace_as_expression('None')) + + template = """ + target, pop_var_name = ag__.list_pop( + target, element, + opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape)) + """ + return templates.replace( + template, + target=original_call_node.func.value, + pop_var_name=pop_var_name, + element=pop_element, + dtype=dtype, + shape=shape) + + def _postprocess_statement(self, node): + """Inserts any separate pop() calls that node may use.""" + pop_uses = self.get_local(POP_USES, None) + if pop_uses: + replacements = [] + for original_call_node, pop_var_name in pop_uses: + replacements.extend( + self._generate_pop_operation(original_call_node, pop_var_name)) + replacements.append(node) + node = replacements + self.exit_local_scope() + return node, None + + # TODO(mdan): Should we have a generic visit_block instead? + # Right now it feels that a visit_block would add too much magic that's + # hard to follow. + + def _visit_and_process_block(self, block): + return self.visit_block( + block, + before_visit=self.enter_local_scope, + after_visit=self._postprocess_statement) + + def visit_FunctionDef(self, node): + node.args = self.generic_visit(node.args) + node.decorator_list = self.visit_block(node.decorator_list) + node.body = self._visit_and_process_block(node.body) + return node + + def visit_For(self, node): + node.target = self.visit(node.target) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_While(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_If(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node - # Only convert lists when they are assigned to a variable, e.g.: - # l = [] - # TODO(mdan): A similar pattern exists in type_info.py - # We should add a generic "unpack_assignment" function to the base - # transformer, that has the same effect as applying some logic to the SSA - # form. - node.value = self._replace_list_constructors(node.targets, node.value) + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body = self._visit_and_process_block(node.body) return node -def transform(node, context): - return ListTransformer(context).visit(node) +def transform(node, ctx): + return ListTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index 74c6dc64f197f75eb3e66c01fb078467e8e8ea89..ea04097b28deedd705164bd95ab62dba3e3c7834 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -19,77 +19,129 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph import utils -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import lists +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import dtypes -from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -class ListTest(converter_test_base.TestCase): +class ListTest(converter_testing.TestCase): - def test_empty_annotated_list(self): + def test_empty_list(self): def test_fn(): - l = [] - utils.set_element_type(l, dtypes.int32) - l.append(1) - return l + return [] - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = self.parse_and_analyze(test_fn, {}) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: - # TODO(mdan): Attach these additional modules automatically. - result.utils = utils - result.dtypes = dtypes + with self.compiled(node) as result: + tl = result.test_fn() + # Empty tensor lists cannot be evaluated or stacked. + self.assertTrue(isinstance(tl, ops.Tensor)) + self.assertEqual(tl.dtype, dtypes.variant) + + def test_initialized_list(self): + + def test_fn(): + return [1, 2, 3] + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: with self.test_session() as sess: - self.assertAllEqual([1], sess.run(result.test_fn().stack())) + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) - def test_empty_annotated_lists_unpacked(self): + def test_list_append(self): def test_fn(): - l, m = [], [] - utils.set_element_type(l, dtypes.int32) - utils.set_element_type(m, dtypes.int32) - l.append(1) - m.append(2) - return l, m + l = [1] + l.append(2) + l.append(3) + return l - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = self.parse_and_analyze(test_fn, {}) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: + with self.compiled(node) as result: + with self.test_session() as sess: + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) + + def test_list_pop(self): + + def test_fn(): + l = [1, 2, 3] + utils.set_element_type(l, dtypes.int32, ()) + s = l.pop() + return s, l + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - res_l, res_m = result.test_fn() - self.assertEqual([1], sess.run(res_l.stack())) - self.assertEqual([2], sess.run(res_m.stack())) + ts, tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2]) + self.assertAllEqual(sess.run(ts), 3) + + def test_double_list_pop(self): - def test_empty_annotated_lists_list_unpacked(self): + def test_fn(l): + s = l.pop().pop() + return s + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: + test_input = [1, 2, [1, 2, 3]] + # TODO(mdan): Pass a list of lists of tensor when we fully support that. + # For now, we just pass a regular Python list of lists just to verify that + # the two pop calls are sequenced properly. + self.assertAllEqual(result.test_fn(test_input), 3) + + def test_list_stack(self): + + tf = None # Will be replaced with a mock. def test_fn(): - [l, m] = [], [] + l = [1, 2, 3] utils.set_element_type(l, dtypes.int32) - utils.set_element_type(m, dtypes.int32) - l.append(1) - m.append(2) - return l, m - - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + return tf.stack(l) + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: + with self.compiled(node, array_ops.stack, dtypes.int32) as result: result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - res_l, res_m = result.test_fn() - self.assertEqual([1], sess.run(res_l.stack())) - self.assertEqual([2], sess.run(res_m.stack())) + self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py index 3a795a315a3c2aa08ac1577a204102755b6e849c..16eb1f0e3f8ad34e615931882ab2896db485f457 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -23,10 +23,10 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer # TODO(mdan): Properly extrack boolean ops according to lazy eval rules. @@ -39,11 +39,11 @@ from tensorflow.contrib.autograph.pyct import transformer SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND' -class LogicalExpressionTransformer(transformer.Base): +class LogicalExpressionTransformer(converter.Base): """Converts logical expressions to corresponding TF calls.""" - def __init__(self, context): - super(LogicalExpressionTransformer, self).__init__(context) + def __init__(self, ctx): + super(LogicalExpressionTransformer, self).__init__(ctx) # TODO(mdan): Look into replacing with bitwise operators instead. # TODO(mdan): Skip replacing if the function is trivial. self.op_mapping = { @@ -128,5 +128,5 @@ class LogicalExpressionTransformer(transformer.Base): return right -def transform(node, context): - return LogicalExpressionTransformer(context).visit(node) +def transform(node, ctx): + return LogicalExpressionTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py index 2814060c4d831e4dddacb3dcbcbe1db42160db20..48186024a9da7b41fa7ff9a8ab18f3477ba09c8f 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import logical_expressions +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class GradientsFunctionTest(converter_test_base.TestCase): +class GradientsFunctionTest(converter_testing.TestCase): def test_equals(self): diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py index dfee529abaa8c14d9b408819b32c5199500a2c2f..dd6c6bf960c52d094a16d4cd72fa84f65b9322a1 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes.py +++ b/tensorflow/contrib/autograph/converters/name_scopes.py @@ -20,11 +20,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer -class FunctionNameScopeTransformer(transformer.Base): +class FunctionNameScopeTransformer(converter.Base): """Wrap a function body with a `name_scope` of the function name.""" def _name_for_current_scope(self): @@ -70,5 +70,5 @@ class FunctionNameScopeTransformer(transformer.Base): return node -def transform(node, context): - return FunctionNameScopeTransformer(context).visit(node) +def transform(node, ctx): + return FunctionNameScopeTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py index 17692cbd880dbc1db4bb40ad7345e27907499f9d..444d0bcd469f35689d078debe3622f930dbac723 100644 --- a/tensorflow/contrib/autograph/converters/name_scopes_test.py +++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import name_scopes +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test -class FunctionNameScopeTransformer(converter_test_base.TestCase): +class FunctionNameScopeTransformer(converter_testing.TestCase): def test_basic(self): diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py index 3bcb2d3c42c6e0663c8f78523199a364b6ac231f..b808604f0ab2d42f41a560035ab046ff782a3431 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py @@ -36,11 +36,11 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -59,14 +59,9 @@ class SymbolNamer(object): raise NotImplementedError() -class SideEffectGuardTransformer(transformer.Base): +class SideEffectGuardTransformer(converter.Base): """Adds control dependencies to functions with side effects.""" - def __init__(self, context): - super(SideEffectGuardTransformer, self).__init__(context) - - # pylint:disable=invalid-name - def _visit_and_reindent(self, nodes): new_nodes = [] current_dest = new_nodes @@ -149,7 +144,7 @@ class SideEffectGuardTransformer(transformer.Base): s for s in guarded_args if s not in args_scope.parent.modified) aliased_new_names = tuple( qual_names.QN( - self.context.namer.new_symbol( + self.ctx.namer.new_symbol( s.ssf(), args_scope.parent.referenced)) for s in need_alias) alias_map = dict(zip(need_alias, aliased_new_names)) if len(guarded_args) == 1: @@ -183,8 +178,6 @@ class SideEffectGuardTransformer(transformer.Base): (node.body, alias_map)) return node - # pylint:enable=invalid-name - -def transform(node, context): - return SideEffectGuardTransformer(context).visit(node) +def transform(node, ctx): + return SideEffectGuardTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py index ce0ce33243a1352107eb8121050ee76474869809..a7ad8efed4c88e15ce9dc14cb02e5e035602013d 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import side_effect_guards +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops @@ -29,7 +29,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class SideEffectGuardsTest(converter_test_base.TestCase): +class SideEffectGuardsTest(converter_testing.TestCase): def test_side_effect_on_return_only_variable(self): diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/single_return.py index bcc9ca9dfeb00ef2d2e60edf6a1abfba19a1bad7..a351cd81b82f7fb32f62ac1579355ace0501759d 100644 --- a/tensorflow/contrib/autograph/converters/single_return.py +++ b/tensorflow/contrib/autograph/converters/single_return.py @@ -20,21 +20,21 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import templates -from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Move this logic into transformer_base. -class BodyVisitor(transformer.Base): +class BodyVisitor(converter.Base): """Walks breadth- or depth-first the list-of-nodes bodies of AST nodes.""" - def __init__(self, context, depth_first=False): + def __init__(self, ctx, depth_first=False): + super(BodyVisitor, self).__init__(ctx) self.depth_first = depth_first self.changes_made = False - super(BodyVisitor, self).__init__(context) def visit_nodelist(self, nodelist): for node in nodelist: @@ -144,13 +144,13 @@ def contains_return(node): return False -class LiftReturn(transformer.Base): +class LiftReturn(converter.Base): """Move return statements out of If and With blocks.""" - def __init__(self, context): + def __init__(self, ctx): + super(LiftReturn, self).__init__(ctx) self.changes_made = False self.common_return_name = None - super(LiftReturn, self).__init__(context) def visit_If(self, node): # Depth-first traversal of if statements @@ -195,8 +195,8 @@ class LiftReturn(transformer.Base): last_return_name = self.common_return_name body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) referenced_names = body_scope.referenced - self.common_return_name = self.context.namer.new_symbol( - 'return_', referenced_names) + self.common_return_name = self.ctx.namer.new_symbol('return_', + referenced_names) node = self.generic_visit(node) self.common_return_name = last_return_name return node @@ -265,7 +265,7 @@ class DetectReturnInFunctionDef(gast.NodeVisitor): 'Each function definition should contain at least one return.') -def transform(node, context): +def transform(node, ctx): """Ensure a function has only a single return. This transforms an AST node with multiple returns successively into containing @@ -280,8 +280,8 @@ def transform(node, context): this is an error. Args: - node: an AST node to transform - context: a context object + node: ast.AST + ctx: converter.EntityContext Returns: new_node: an AST with a single return value @@ -301,10 +301,10 @@ def transform(node, context): while True: # Try to lift all returns out of if statements and with blocks - lr = LiftReturn(context) + lr = LiftReturn(ctx) node = lr.visit(node) changes_made = lr.changes_made - fe = FoldElse(context) + fe = FoldElse(ctx) node = fe.visit(node) changes_made = changes_made or fe.changes_made diff --git a/tensorflow/contrib/autograph/converters/single_return_test.py b/tensorflow/contrib/autograph/converters/single_return_test.py index d483005a09537ea8227814f65aa7e6402c853f60..1f0de4310e370235a4a7bfeaa61bd519a81aff47 100644 --- a/tensorflow/contrib/autograph/converters/single_return_test.py +++ b/tensorflow/contrib/autograph/converters/single_return_test.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import single_return +from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework.ops import name_scope from tensorflow.python.platform import test -class SingleReturnTest(converter_test_base.TestCase): +class SingleReturnTest(converter_testing.TestCase): def compiled_fn(self, test_fn, *args): node = self.parse_and_analyze(test_fn, {}) diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py new file mode 100644 index 0000000000000000000000000000000000000000..3f5fc57125a8b65faf1e3a377d7984ff05b3245c --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -0,0 +1,83 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for slice operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates + + +class SliceTransformer(converter.Base): + """Converts slicing operations to their TF counterpart. + + Currently, relying on the default slice operator that Tensor uses is + insufficient, because TensorArray and tensor lists use dedicated index read + and write functions. + """ + + def _process_single_assignment(self, target, value): + if not isinstance(target, gast.Subscript): + return None + + template = """ + target = ag__.set_item(target, key, item) + """ + return templates.replace( + template, target=target.value, key=target.slice, item=value) + + def visit_Assign(self, node): + node = self.generic_visit(node) + # TODO(mdan): Support unpackings and multiple assignments. + if len(node.targets) != 1: + raise NotImplementedError('multiple assignment') + replacement = self._process_single_assignment(node.targets[0], node.value) + if replacement is not None: + return replacement + return node + + def visit_Subscript(self, node): + node = self.generic_visit(node) + if not isinstance(node.slice, gast.Index): + # TODO(mdan): It might make more sense to wave them through. + raise NotImplementedError('non-index slice') + + if not isinstance(node.ctx, gast.Load): + # Index writes are handled at a higher level, one at which the rvalue is + # also available. + return node + + dtype = anno.getanno( + node.value, + 'element_type', + default=templates.replace_as_expression('None')) + + template = """ + ag__.get_item( + target, + key, + opts=ag__.GetItemOpts(element_dtype=dtype)) + """ + return templates.replace_as_expression( + template, target=node.value, key=node.slice, dtype=dtype) + + +def transform(node, ctx): + return SliceTransformer(ctx).visit(node) diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py new file mode 100644 index 0000000000000000000000000000000000000000..df9a4c8bab66f24374605b45bc90bc2730431323 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import slices +from tensorflow.contrib.autograph.core import converter_testing +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SliceTest(converter_testing.TestCase): + + def test_index_access(self): + + def test_fn(l): + utils.set_element_type(l, dtypes.int32) + return l[1] + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = slices.transform(node, self.ctx) + + with self.compiled(node, dtypes.int32) as result: + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + tl = list_ops.tensor_list_from_tensor( + [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) + y = result.test_fn(tl) + self.assertEqual(2, sess.run(y)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/core/BUILD b/tensorflow/contrib/autograph/core/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..833f9dced81bd651244d281322c830bb1c88b259 --- /dev/null +++ b/tensorflow/contrib/autograph/core/BUILD @@ -0,0 +1,59 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "core", + srcs = [ + "config.py", + "converter.py", + "naming.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", + ], +) + +py_library( + name = "test_lib", + srcs = [ + "converter_testing.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":core", + "//tensorflow/contrib/autograph/operators", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/pyct/static_analysis", + "//tensorflow/contrib/autograph/utils", + "@gast_archive//:gast", + "@six_archive//:six", + ], +) + +py_test( + name = "naming_test", + srcs = ["naming_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/impl/config.py b/tensorflow/contrib/autograph/core/config.py similarity index 100% rename from tensorflow/contrib/autograph/impl/config.py rename to tensorflow/contrib/autograph/core/config.py diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..54e6aa0f3bbb9059e044861362407cb5050240b4 --- /dev/null +++ b/tensorflow/contrib/autograph/core/converter.py @@ -0,0 +1,210 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter construction support. + +This module contains a base class for all converters, as well as supporting +structures. These structures are referred to as contexts. + +The class hierarchy is as follows: + + + [extends] converter.Base + [extends] transformer.Base + [extends] gast.nodeTransformer + [uses] transfomer.SourceInfo + [uses] converter.EntityContext + [uses] converter.ProgramContext + [uses] transfomer.SourceInfo + +converter.Base is a specialization of transformer.Base for AutoGraph. It's a +very lightweight subclass that adds a `ctx` attribute holding the corresponding +EntityContext object (see below). Note that converters are not reusable, and +`visit` will raise an error if called more than once. + +converter.EntityContext contains mutable state associated with an entity that +the converter processes. + +converter.ProgramContext contains mutable state across related entities. For +example, when converting several functions that call one another, the +ProgramContext should be shared across these entities. + +Below is the overal flow at conversion: + + program_ctx = ProgramContext(, , ...) + while : + entity, source_info = + entity_ctx = EntityContext(program_ctx, source_info) + for : + converter = ConverterClass(entity_ctx) + + # May update entity_ctx and program_ctx + entity = converter.visit(entity) + + + +Note that pyct contains a small number of transformers used for static analysis. +These implement transformer.Base, rather than converter.Base, to avoid a +dependency on AutoGraph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import naming +from tensorflow.contrib.autograph.pyct import transformer + +# TODO(mdan): These contexts can be refactored into first class objects. +# For example, we could define Program and Entity abstractions that hold on +# to the actual entity and have conversion methods. + + +class ProgramContext(object): + """ProgramContext keeps track of converting function hierarchies. + + This object is mutable, and is updated during conversion. Not thread safe. + + Attributes: + recursive: bool, whether to recursively convert any functions that the + decorator function may call. + autograph_decorators: Tuple[Callable, ...], decorator functions that belong + to AutoGraph. These require special treatment. + dependency_cache: Dict[Any, ast.AST], the original entities mapped to their + converted AST + additional_imports: Set[Any], additional entities which for any reason + cannot be attached after loading and need to be explicitly imported + in the generated code + name_map: Dict[str, str], map of original entity name to the name of + their converted counterparts + autograph_module: Module, a reference to the autograph module. This + needs to be specified by the caller to avoid circular dependencies. + uncompiled_modules: Set[Tuple[str, ...]], with each tuple representing the + fully qualified name of a package containing functions that will not be + compiled. + required_imports: str, containing an import statement on each line. These + are all the imports necessary for the compiled code to run, in addition + to the closures of each entity, which are attached dynamically. + """ + + def __init__( + self, + recursive, + autograph_decorators, + partial_types, + autograph_module, + uncompiled_modules, + ): + self.recursive = recursive + self.autograph_decorators = autograph_decorators + self.partial_types = partial_types if partial_types else () + self.autograph_module = autograph_module + self.uncompiled_modules = uncompiled_modules + + # Required to output dependencies in discovery order, which should match + # the reverse dependency order. + self.dependency_cache = collections.OrderedDict() + self.additional_imports = set() + self.name_map = {} + + @property + def required_imports(self): + """Returns a block containing all imports required by the converted code.""" + # TODO(mdan): Check that these don't clobber one another. + return '\n'.join(config.COMPILED_IMPORT_STATEMENTS + + tuple(self.additional_imports)) + + def new_namer(self, namespace): + return naming.Namer(namespace, self.recursive, self.name_map, + self.partial_types) + + def update_name_map(self, namer): + """Updates renamed_calls based on the recent activity from the namer. + + Whenever we convert a new entity, any references to other entities are being + renamed to match their soon-to-be-converted counterparts. The namer keeps + track of these renames. When conversion is complete, we copy those renames + so that when those referenced entities are being converted, their new name + matches. + + Args: + namer: naming.Namer + + Raises: + ValueError: when an entity was renamed twice and to different names. + """ + # TODO(mdan): Have call_trees do this directly. + # This is done so indirectly, via the namer, for historic reasons. But + # now we can have the converter that does the rename record the new name + # as well and skip this step altogether. + for o, name in namer.renamed_calls.items(): + if o in self.name_map: + if self.name_map[o] != name: + raise ValueError( + 'Calls to %s were converted using multiple names (%s). This is ' + 'possible when an entity with one of these names already ' + 'existed. To fix, avoid using any of these names.' % + (o, (name, self.name_map[o]))) + else: + self.name_map[o] = name + + def add_to_cache(self, original_entity, converted_ast): + self.dependency_cache[original_entity] = converted_ast + + +class EntityContext(object): + """Tracks the conversion of a single entity. + + This object is mutable, and is updated during conversion. Not thread safe. + + Attributes: + namer: Namer + info: transformer.EntityInfo + program: ProgramContext + """ + + def __init__(self, namer, entity_info, program_ctx): + self.namer = namer + self.info = entity_info + self.program = program_ctx + + +class Base(transformer.Base): + """All converters should inherit from this class. + + Attributes: + ctx: EntityContext + """ + + def __init__(self, ctx): + super(Base, self).__init__(ctx.info) + self.ctx = ctx # Keeping this short because it's used frequently. + + self._used = False + self._ast_depth = 0 + + def visit(self, node): + if not self._ast_depth: + if self._used: + raise ValueError('converter objects cannot be reused') + self._used = True + + self._ast_depth += 1 + try: + return super(Base, self).visit(node) + finally: + self._ast_depth -= 1 diff --git a/tensorflow/contrib/autograph/converters/converter_test_base.py b/tensorflow/contrib/autograph/core/converter_testing.py similarity index 80% rename from tensorflow/contrib/autograph/converters/converter_test_base.py rename to tensorflow/contrib/autograph/core/converter_testing.py index 41c2e71702e7e3ee3811a2cbee27c8c988eb3a5c..0e46aacc1216d2dbd9d34ad0e72ca8251094bddc 100644 --- a/tensorflow/contrib/autograph/converters/converter_test_base.py +++ b/tensorflow/contrib/autograph/core/converter_testing.py @@ -23,17 +23,24 @@ import imp from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import compiler -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import pretty_printer from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info from tensorflow.python.platform import test +def imported_decorator(f): + return lambda a: f(a) + 1 + + +# TODO(mdan): We might be able to use the real namer here. class FakeNamer(object): """A fake namer that uses a global counter to generate unique names.""" @@ -114,23 +121,32 @@ class TestCase(test.TestCase): arg_types=None, include_type_analysis=True, owner_type=None, - recursive=True): + recursive=True, + autograph_decorators=()): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=namer or FakeNamer(), + + if namer is None: + namer = FakeNamer() + program_ctx = converter.ProgramContext( + recursive=recursive, + autograph_decorators=autograph_decorators, + partial_types=None, + autograph_module=None, + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + entity_info = transformer.EntityInfo( source_code=source, - source_file=None, + source_file='', namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=owner_type, - recursive=recursive, - type_annotation_func=utils.set_element_type) + owner_type=owner_type) + ctx = converter.EntityContext(namer, entity_info, program_ctx) + node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) + node = activity.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) if include_type_analysis: - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) self.ctx = ctx return node diff --git a/tensorflow/contrib/autograph/impl/naming.py b/tensorflow/contrib/autograph/core/naming.py similarity index 100% rename from tensorflow/contrib/autograph/impl/naming.py rename to tensorflow/contrib/autograph/core/naming.py diff --git a/tensorflow/contrib/autograph/impl/naming_test.py b/tensorflow/contrib/autograph/core/naming_test.py similarity index 98% rename from tensorflow/contrib/autograph/impl/naming_test.py rename to tensorflow/contrib/autograph/core/naming_test.py index 73fc0894655cb49e4f61bf8ca51995b06feb3072..d2bebd0478b1074e421b5da1427a0dbaf91b6c9f 100644 --- a/tensorflow/contrib/autograph/impl/naming_test.py +++ b/tensorflow/contrib/autograph/core/naming_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.core import naming from tensorflow.python.platform import test diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb index d62390494b78c415212ba91ac914cdfee324f971..0702273fac15da61a72d66d8344a5add32ad12a6 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb @@ -570,7 +570,7 @@ " autograph.utils.set_element_type(numbers, tf.int32)\n", " for i in range(n):\n", " numbers.append(i)\n", - " return numbers.stack() # Stack the list so that it can be used as a Tensor\n", + " return autograph.stack(numbers) # Stack the list so that it can be used as a Tensor\n", "\n", "\n", "tf_f = autograph.to_graph(f)\n", @@ -648,7 +648,7 @@ " if not is_prime:\n", " continue\n", " primes.append(i)\n", - " all_primes = primes.stack()\n", + " all_primes = autograph.stack(primes)\n", "\n", " print('The prime numbers less than', n, 'are:')\n", " print(all_primes)\n", @@ -953,8 +953,9 @@ " train_accuracies.append(step_train_accuracy)\n", " test_accuracies.append(step_test_accuracy)\n", " i += 1\n", - " return (train_losses.stack(), test_losses.stack(), train_accuracies.stack(),\n", - " test_accuracies.stack())" + " return (autograph.stack(train_losses), autograph.stack(test_losses),\n", + " autograph.stack(train_accuracies),\n", + " autograph.stack(test_accuracies))" ], "execution_count": 0, "outputs": [] @@ -1236,7 +1237,7 @@ " cell_output, (state, output) = cell.call(ch, (state, output))\n", " hidden_outputs.append(cell_output)\n", " i += 1\n", - " hidden_outputs = hidden_outputs.stack()\n", + " hidden_outputs = autograph.stack(hidden_outputs)\n", " if training:\n", " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", " return hidden_outputs\n", diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD index 91ae0b9b82c6f649c3c80b91ef894b2221cdc962..a5438592c30021eac7183b65ccc10c36d220bc57 100644 --- a/tensorflow/contrib/autograph/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -18,19 +18,19 @@ py_library( name = "impl", srcs = [ "api.py", - "config.py", "conversion.py", - "naming.py", - "special_functions.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib/autograph/converters", + "//tensorflow/contrib/autograph/core", "//tensorflow/contrib/autograph/operators", "//tensorflow/contrib/autograph/pyct", "//tensorflow/contrib/autograph/pyct/static_analysis", "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:platform", + "//tensorflow/python:util", "@gast_archive//:gast", "@six_archive//:six", ], @@ -60,23 +60,3 @@ py_test( "@gast_archive//:gast", ], ) - -py_test( - name = "naming_test", - srcs = ["naming_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":impl", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "special_functions_test", - srcs = ["special_functions_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":impl", - "//tensorflow/python:client_testlib", - ], -) diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index 24f87b2c14da4a3523f1e580d4362cbd3679a2cd..209e494ac2b313bdaec44b93915b02cef759cba3 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -27,11 +27,11 @@ import gast import six # pylint:enable=g-bad-import-order -from tensorflow.contrib.autograph.impl import config +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import conversion from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import inspect_utils -from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.utils import builtins from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.platform import tf_logging as logging @@ -230,20 +230,20 @@ def to_graph(e, A function with a signature identical to `o`, but which when executed it creates TF a graph that has the same functionality as the original entity. """ - conversion_map = conversion.ConversionMap( + program_ctx = converter.ProgramContext( recursive=recursive, - nocompile_decorators=(convert, do_not_convert, converted_call), + autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, - api_module=tf_inspect.getmodule(to_graph)) - _, name, namespace = conversion.entity_to_graph(e, conversion_map, arg_values, + autograph_module=tf_inspect.getmodule(to_graph), + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) module = gast.Module([]) - for import_line in config.COMPILED_IMPORT_STATEMENTS: - module.body.extend(parser.parse_str(import_line).body) - for dep in reversed(conversion_map.dependency_cache.values()): + for dep in reversed(program_ctx.dependency_cache.values()): module.body.append(dep) - compiled_node, compiled_src = compiler.ast_to_object(module) + compiled_node, compiled_src = compiler.ast_to_object( + module, source_prefix=program_ctx.required_imports) # The compiled code should see everything the entry entity saw. # TODO(mdan): This might not work well if the call tree spans modules? @@ -280,17 +280,16 @@ def to_code(e, Returns: String. """ - conversion_map = conversion.ConversionMap( + program_ctx = converter.ProgramContext( recursive=recursive, - nocompile_decorators=(convert, do_not_convert, converted_call), + autograph_decorators=(convert, do_not_convert, converted_call), partial_types=partial_types, - api_module=tf_inspect.getmodule(to_graph)) - conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) + autograph_module=tf_inspect.getmodule(to_graph), + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) + conversion.entity_to_graph(e, program_ctx, arg_values, arg_types) - imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS) code = '\n'.join( compiler.ast_to_source(dep, indentation) - for dep in reversed(tuple( - six.itervalues(conversion_map.dependency_cache)))) + for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache)))) - return imports + '\n\n' + code + return program_ctx.required_imports + '\n\n' + code diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index a7737b7f448131b1c54951efa719b481e1f4d0c9..ed9fbdd23002df4361de18f00e6b15738fcded52 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -21,8 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.impl import api -from tensorflow.contrib.autograph.impl import config from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.framework import constant_op diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 55a30dc127957b2a9caa053db843380c94bacfbf..776d19f672ebbd6b88985dda157434f2046d87e7 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""High level conversion support.""" +"""Core conversion logic, serves as main point of access.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import imp import gast @@ -38,77 +37,23 @@ from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.contrib.autograph.converters import name_scopes from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.contrib.autograph.converters import single_return -from tensorflow.contrib.autograph.impl import config -from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.converters import slices +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.pyct import ast_util -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info -from tensorflow.contrib.autograph.utils import type_hints from tensorflow.python.util import tf_inspect # TODO(mdan): Might we not need any renaming at all? -class ConversionMap(object): - """ConversionMap keeps track of converting function hierarchies. - - This object is mutable, and is updated as functions are converted. - - Attributes: - recursive: Whether to recursively convert any functions that the decorator - function may call. - nocompile_decorators: tuple of decorator functions that toggle compilation - off. - dependency_cache: dict[object]: ast; maps original entities to their - converted AST - additional_imports: set(object); additional entities which for any reason - cannot be attached after loading and need to be explicitly imported - in the generated code - name_map: dict[string]: string; maps original entities to the name of - their converted counterparts - api_module: A reference to the api module. The reference needs to be passed - to avoid circular dependencies. - """ - - # TODO(mdan): Rename to ConversionContext, and pull in additional flags. - - def __init__(self, recursive, nocompile_decorators, partial_types, - api_module): - self.recursive = recursive - self.nocompile_decorators = nocompile_decorators - self.partial_types = partial_types if partial_types else () - # Required to output dependencies in discovery order, which should match - # the reverse dependency order. - self.dependency_cache = collections.OrderedDict() - self.additional_imports = set() - self.name_map = {} - self.api_module = api_module - - def new_namer(self, namespace): - return naming.Namer(namespace, self.recursive, self.name_map, - self.partial_types) - - def update_name_map(self, namer): - for o, name in namer.renamed_calls.items(): - if o in self.name_map: - if self.name_map[o] != name: - raise ValueError( - 'Calls to %s were converted using multiple names (%s). This is ' - 'possible when an entity with one of these names already ' - 'existed. To fix, avoid using any of these names.') - else: - self.name_map[o] = name - - def add_to_cache(self, original_entity, converted_ast): - self.dependency_cache[original_entity] = converted_ast - - def is_whitelisted_for_graph(o): """Check whether an entity is whitelisted for use in graph mode. @@ -127,7 +72,7 @@ def is_whitelisted_for_graph(o): return False -def entity_to_graph(o, conversion_map, arg_values, arg_types): +def entity_to_graph(o, program_ctx, arg_values, arg_types): """Compile a Python entity into equivalent TensorFlow. The function will also recursively compile all the entities that `o` @@ -138,7 +83,7 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): Args: o: A Python entity. - conversion_map: A ConversionMap object. + program_ctx: A ProgramContext object. arg_values: A dict containing value hints for symbols like function parameters. arg_types: A dict containing type hints for symbols like function @@ -156,7 +101,7 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): ValueError: if the entity type is not supported. """ if tf_inspect.isclass(o): - node, name, ns = class_to_graph(o, conversion_map) + node, name, ns = class_to_graph(o, program_ctx) elif tf_inspect.isfunction(o): # TODO(mdan): This is not a reliable mechanism. # The most reliable way is to check the source code, the AST will contain @@ -166,36 +111,35 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): 'lambda functions are not yet supported; declare the function' ' using def instead: %s' % o) else: - node, name, ns = function_to_graph(o, conversion_map, arg_values, - arg_types) + node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) elif tf_inspect.ismethod(o): - node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types) + node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) else: raise ValueError( 'Entity "%s" has unsupported type "%s". Only functions and classes are ' 'supported for now.' % (o, type(o))) - conversion_map.add_to_cache(o, node) - if conversion_map.recursive: + program_ctx.add_to_cache(o, node) + if program_ctx.recursive: while True: candidate = None - for obj in conversion_map.name_map.keys(): - if obj not in conversion_map.dependency_cache: + for obj in program_ctx.name_map.keys(): + if obj not in program_ctx.dependency_cache: candidate = obj break if candidate is None: break if (hasattr(candidate, 'im_class') and - getattr(candidate, 'im_class') not in conversion_map.partial_types): + getattr(candidate, 'im_class') not in program_ctx.partial_types): # Class members are converted with their objects, unless they're # only converted partially. continue - entity_to_graph(candidate, conversion_map, {}, {}) + entity_to_graph(candidate, program_ctx, {}, {}) return node, name, ns -def class_to_graph(c, conversion_map): +def class_to_graph(c, program_ctx): """Specialization of `entity_to_graph` for classes.""" converted_members = {} method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) @@ -210,7 +154,7 @@ def class_to_graph(c, conversion_map): continue node, _, namespace = function_to_graph( m, - conversion_map=conversion_map, + program_ctx=program_ctx, arg_values={}, arg_types={'self': (c.__name__, c)}, owner_type=c) @@ -219,14 +163,14 @@ def class_to_graph(c, conversion_map): else: class_namespace.update(namespace) converted_members[m] = node - namer = conversion_map.new_namer(class_namespace) + namer = program_ctx.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) # TODO(mdan): This needs to be explained more thoroughly. # Process any base classes: if the sueprclass if of a whitelisted type, an # absolute import line is generated. Otherwise, it is marked for conversion # (as a side effect of the call to namer.compiled_class_name() followed by - # conversion_map.update_name_map(namer)). + # program_ctx.update_name_map(namer)). output_nodes = [] renames = {} bases = [] @@ -246,7 +190,7 @@ def class_to_graph(c, conversion_map): alias = namer.compiled_class_name(base.__name__, base) bases.append(alias) renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) - conversion_map.update_name_map(namer) + program_ctx.update_name_map(namer) # Generate the definition of the converted class. output_nodes.append( @@ -278,14 +222,14 @@ def _add_reserved_symbol(namespace, name, entity): ag_internal = None -def _add_self_references(namespace, api_module): +def _add_self_references(namespace, autograph_module): """Adds namespace references to the module that exposes the api itself.""" global ag_internal if ag_internal is None: # Craft a module that exposes parts of the external API as well as certain # internal modules. ag_internal = imp.new_module('autograph') - ag_internal.converted_call = api_module.converted_call + ag_internal.converted_call = autograph_module.converted_call ag_internal.utils = utils # TODO(mdan): Add safeguards against name clashes. # We don't want to create a submodule because we want the operators to be @@ -295,27 +239,24 @@ def _add_self_references(namespace, api_module): _add_reserved_symbol(namespace, 'ag__', ag_internal) -def function_to_graph(f, conversion_map, arg_values, arg_types, - owner_type=None): +def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None): """Specialization of `entity_to_graph` for callable functions.""" node, source = parser.parse_entity(f) node = node.body[0] namespace = inspect_utils.getnamespace(f) - _add_self_references(namespace, conversion_map.api_module) - namer = conversion_map.new_namer(namespace) + _add_self_references(namespace, program_ctx.autograph_module) + namer = program_ctx.new_namer(namespace) - ctx = context.EntityContext( - namer=namer, + entity_info = transformer.EntityInfo( source_code=source, source_file='', namespace=namespace, arg_values=arg_values, arg_types=arg_types, - owner_type=owner_type, - recursive=conversion_map.recursive, - type_annotation_func=type_hints.set_element_type) - node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators) + owner_type=owner_type) + context = converter.EntityContext(namer, entity_info, program_ctx) + node = node_to_graph(node, context) # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) @@ -325,29 +266,28 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, raise NotImplementedError('Strange corner case. Send us offending code!') node.name = new_name - conversion_map.update_name_map(namer) + program_ctx.update_name_map(namer) # TODO(mdan): Use this at compilation. - conversion_map.additional_imports.update(deps) return node, new_name, namespace -def _static_analysis_pass(node, ctx): +def _apply_transformer(node, context, converter_module): + # TODO(mdan): Clear static analysis here. node = qual_names.resolve(node) - node = activity.resolve(node, ctx, None) - node = live_values.resolve(node, ctx, config.PYTHON_LITERALS) - node = type_info.resolve(node, ctx) + node = activity.resolve(node, context.info, None) + node = live_values.resolve(node, context.info, config.PYTHON_LITERALS) + node = type_info.resolve(node, context.info) + node = converter_module.transform(node, context) return node -def node_to_graph(node, ctx, nocompile_decorators): +def node_to_graph(node, context): """Convert Python code to equivalent TF graph mode code. Args: - node: A Python AST node representing the code to convert. - ctx: An EntityContext object. - nocompile_decorators: A tuple containing decorators to be stripped from - functions during conversion. + node: AST, the code to convert. + context: converter.EntityContext Returns: A tuple (node, deps): @@ -357,53 +297,26 @@ def node_to_graph(node, ctx, nocompile_decorators): """ # TODO(mdan): Verify arguments for correctness. - # TODO(mdan): Factor out common elements. - # These include: - # * code move between blocks - # * visiting blocks in transformers - - # Certain steps, especially canonicalization, insert new symbols into the - # tree, which must be accounted. Although less efficient, it is most robust - # to re-run the analysis. - - node = _static_analysis_pass(node, ctx) - - # TODO(mdan): Clean this up. - # Some intermediate analyses are not required, and some comments got orphaned. - + node = _apply_transformer(node, context, ifexp) # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? - ctx.source_code = None - node = ifexp.transform(node, ctx) - node, deps = decorators.transform(node, nocompile_decorators) - node = break_statements.transform(node, ctx) - node = _static_analysis_pass(node, ctx) - - node = asserts.transform(node, ctx) - + context.info.source_code = None + node = _apply_transformer(node, context, decorators) + node = _apply_transformer(node, context, break_statements) + node = _apply_transformer(node, context, asserts) # Note: sequencing continue canonicalization before for loop one avoids # dealing with the extra loop increment operation that the for # canonicalization creates. - node = continue_statements.transform(node, ctx) - ctx.namespace['len'] = len - - node = _static_analysis_pass(node, ctx) - node = single_return.transform(node, ctx) - - node = _static_analysis_pass(node, ctx) - node = lists.transform(node, ctx) - node = builtin_functions.transform(node, ctx) - - node = _static_analysis_pass(node, ctx) - node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES, - nocompile_decorators) - node = control_flow.transform(node, ctx) - - # control_flow may create new symbols and change scopes. - node = _static_analysis_pass(node, ctx) - node = logical_expressions.transform(node, ctx) - node = side_effect_guards.transform(node, ctx) - node = name_scopes.transform(node, ctx) - - return node, deps + node = _apply_transformer(node, context, continue_statements) + context.info.namespace['len'] = len + node = _apply_transformer(node, context, single_return) + node = _apply_transformer(node, context, lists) + node = _apply_transformer(node, context, slices) + node = _apply_transformer(node, context, builtin_functions) + node = _apply_transformer(node, context, call_trees) + node = _apply_transformer(node, context, control_flow) + node = _apply_transformer(node, context, logical_expressions) + node = _apply_transformer(node, context, side_effect_guards) + node = _apply_transformer(node, context, name_scopes) + return node diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index bc61498b5422f5e130bbfeef935d0a796b4f5922..f5279298afdcd406a9a6762e58367cea8ca63141 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.core import config +from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import api from tensorflow.contrib.autograph.impl import conversion from tensorflow.python.framework import constant_op @@ -30,8 +32,13 @@ from tensorflow.python.platform import test class ConversionTest(test.TestCase): - def _simple_conversion_map(self): - return conversion.ConversionMap(True, (), (), api) + def _simple_program_ctx(self): + return converter.ProgramContext( + recursive=True, + autograph_decorators=(), + partial_types=(), + autograph_module=api, + uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES) def test_is_whitelisted_for_graph(self): @@ -44,16 +51,16 @@ class ConversionTest(test.TestCase): def test_entity_to_graph_unsupported_types(self): with self.assertRaises(ValueError): - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph('dummy', conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph('dummy', program_ctx, None, None) def test_entity_to_graph_callable(self): b = 2 def f(a): return a + b - conversion_map = self._simple_conversion_map() - ast, name, ns = conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + ast, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) self.assertTrue(isinstance(ast, gast.FunctionDef), ast) self.assertEqual('tf__f', name) self.assertTrue(ns['b'] is b) @@ -66,18 +73,17 @@ class ConversionTest(test.TestCase): def f(a): return g(a) - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(f, program_ctx, None, None) - self.assertTrue(f in conversion_map.dependency_cache) - self.assertTrue(g in conversion_map.dependency_cache) - self.assertEqual('tf__f', conversion_map.dependency_cache[f].name) + self.assertTrue(f in program_ctx.dependency_cache) + self.assertTrue(g in program_ctx.dependency_cache) + self.assertEqual('tf__f', program_ctx.dependency_cache[f].name) # need the extra .body[0] in order to step past the with tf.name_scope('f') # that is added automatically self.assertEqual( - 'tf__g', - conversion_map.dependency_cache[f].body[0].body[0].value.func.id) - self.assertEqual('tf__g', conversion_map.dependency_cache[g].name) + 'tf__g', program_ctx.dependency_cache[f].body[0].body[0].value.func.id) + self.assertEqual('tf__g', program_ctx.dependency_cache[g].name) def test_entity_to_graph_class_hierarchy(self): @@ -104,16 +110,15 @@ class ConversionTest(test.TestCase): def baz(self): return self.y - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(TestSubclass, program_ctx, None, None) - self.assertTrue(TestBase in conversion_map.dependency_cache) - self.assertTrue(TestSubclass in conversion_map.dependency_cache) + self.assertTrue(TestBase in program_ctx.dependency_cache) + self.assertTrue(TestSubclass in program_ctx.dependency_cache) self.assertEqual('TfTestBase', - conversion_map.dependency_cache[TestBase].body[-1].name) - self.assertEqual( - 'TfTestSubclass', - conversion_map.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestBase].body[-1].name) + self.assertEqual('TfTestSubclass', + program_ctx.dependency_cache[TestSubclass].body[-1].name) def test_entity_to_graph_class_hierarchy_whitelisted(self): @@ -126,24 +131,23 @@ class ConversionTest(test.TestCase): def call(self, x): return 3 * x - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(TestSubclass, program_ctx, None, None) - self.assertTrue(TestSubclass in conversion_map.dependency_cache) - self.assertFalse(training.Model in conversion_map.dependency_cache) + self.assertTrue(TestSubclass in program_ctx.dependency_cache) + self.assertFalse(training.Model in program_ctx.dependency_cache) self.assertEqual( 'Model', - conversion_map.dependency_cache[TestSubclass].body[0].names[0].name) - self.assertEqual( - 'TfTestSubclass', - conversion_map.dependency_cache[TestSubclass].body[-1].name) + program_ctx.dependency_cache[TestSubclass].body[0].names[0].name) + self.assertEqual('TfTestSubclass', + program_ctx.dependency_cache[TestSubclass].body[-1].name) def test_entity_to_graph_lambda(self): f = lambda a: a with self.assertRaises(NotImplementedError): - conversion_map = self._simple_conversion_map() - conversion.entity_to_graph(f, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + conversion.entity_to_graph(f, program_ctx, None, None) def test_ag_module_cached(self): def callee(): @@ -152,11 +156,11 @@ class ConversionTest(test.TestCase): def caller(a): return a() - conversion_map = self._simple_conversion_map() - _, _, callee_ns = conversion.entity_to_graph( - callee, conversion_map, None, None) - _, _, caller_ns = conversion.entity_to_graph( - caller, conversion_map, None, None) + program_ctx = self._simple_program_ctx() + _, _, callee_ns = conversion.entity_to_graph(callee, program_ctx, None, + None) + _, _, caller_ns = conversion.entity_to_graph(caller, program_ctx, None, + None) self.assertTrue(callee_ns['ag__'] is caller_ns['ag__']) diff --git a/tensorflow/contrib/autograph/lang/BUILD b/tensorflow/contrib/autograph/lang/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..77a2184e229003a3403cbe3bf116ad2570274a1b --- /dev/null +++ b/tensorflow/contrib/autograph/lang/BUILD @@ -0,0 +1,40 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "lang", + srcs = [ + "directives.py", + "special_functions.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/contrib/autograph/operators", + ], +) + +py_test( + name = "special_functions_test", + srcs = ["special_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":lang", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/lang/directives.py b/tensorflow/contrib/autograph/lang/directives.py new file mode 100644 index 0000000000000000000000000000000000000000..aabe5d99394a0cb921196d1c6a6b2a9496ea7545 --- /dev/null +++ b/tensorflow/contrib/autograph/lang/directives.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Directives are special no-op functions that serve as compilation markers. + +They provide static information like type hints, compilation and TensorFlow +overrides. + +These serve as annotations in the compiled code, allowing the user some control +over the compilation process. They have no functional role at runtime. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +UNSPECIFIED = object() + + +def set_element_type(entity, dtype, shape=UNSPECIFIED): + """Indicates that the entity is expected hold items of specified type/shape. + + The staged TensorFlow ops will reflect and assert this data type. Ignored + otherwise. + + Args: + entity: The entity to annotate. + dtype: TensorFlow dtype value to assert for entity. + shape: Optional shape to assert for entity. + """ + del entity + del dtype + del shape + + +def set_loop_options( + parallel_iterations=UNSPECIFIED, + back_prop=UNSPECIFIED, + swap_memory=UNSPECIFIED, + maximum_iterations=UNSPECIFIED): + """Specifies additional arguments to be passed to the enclosing while_loop. + + The parameters apply to and only to the immediately enclosing loop. It only + has effect if the loop is staged as a TF while_loop; otherwise the parameters + have no effect. + + Args: + parallel_iterations: See tf.while_loop. + back_prop: See tf.while_loop. + swap_memory: See tf.while_loop. + maximum_iterations: See tf.while_loop. + """ + del parallel_iterations + del back_prop + del swap_memory + del maximum_iterations diff --git a/tensorflow/contrib/autograph/impl/special_functions.py b/tensorflow/contrib/autograph/lang/special_functions.py similarity index 62% rename from tensorflow/contrib/autograph/impl/special_functions.py rename to tensorflow/contrib/autograph/lang/special_functions.py index b7a8177c44c88217560fb7f72c77d3ac1aa0c9ec..11135295a7966bc5d693676fcc71fe43791f2e99 100644 --- a/tensorflow/contrib/autograph/impl/special_functions.py +++ b/tensorflow/contrib/autograph/lang/special_functions.py @@ -26,23 +26,34 @@ from __future__ import print_function from tensorflow.contrib.autograph.operators import data_structures -def stack(list_or_tensor, element_dtype=None): - """Stacks the input, if it admits the notion of stacking. No-op otherwise. +def stack(list_or_tensor, element_dtype=None, strict=True): + """Stacks the input, if it admits the notion of stacking. For example, a list of tensors can be stacked into a larger tensor. This function is similar to tf.stack, but it accepts non-lists and lists of non-tensors as arguments. In the latter case, the function does nothing. Args: - list_or_tensor: Any entity. - element_dtype: Optional dtype for the elements in the list. Required if the - input is stackable, and the list is untyped. + list_or_tensor: Any + element_dtype: tf.DType, optional dtypedtype for the elements in the list. + Required if the input is stackable, and the list is untyped. + strict: bool, if True an error is raised if the input is not stackable. + Otherwise the function is a no-op. Returns: - If the input is stackable, a new object representing the stacked inputs. - Otherwise it returns list_or_tensor unchanged. + Any, if the input is stackable, the result will be a tf.Tensor. Otherwise, + if strict=False, the result will be list_or_tensor. + + Raises: + ValueError: if strict=True and the input is not stackable. """ + if strict: + def raise_error(x): + raise ValueError('%s must be stackable when strict=True' % x) + original_call = raise_error + else: + original_call = lambda x: x return data_structures.list_stack( list_or_tensor, data_structures.ListStackOpts( - element_dtype=element_dtype, original_call=lambda x: x)) + element_dtype=element_dtype, original_call=original_call)) diff --git a/tensorflow/contrib/autograph/impl/special_functions_test.py b/tensorflow/contrib/autograph/lang/special_functions_test.py similarity index 81% rename from tensorflow/contrib/autograph/impl/special_functions_test.py rename to tensorflow/contrib/autograph/lang/special_functions_test.py index 9b52d2a59b5a3e3c92f11343197379c773ecc828..a49cb6407517b634e0f1259fccda03d4ed18e83f 100644 --- a/tensorflow/contrib/autograph/impl/special_functions_test.py +++ b/tensorflow/contrib/autograph/lang/special_functions_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.impl import special_functions +from tensorflow.contrib.autograph.lang import special_functions from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util @@ -29,14 +29,18 @@ from tensorflow.python.platform import test class SpecialFunctionsTest(test.TestCase): def test_basic(self): - self.assertEqual(special_functions.stack(1), 1) - self.assertListEqual(special_functions.stack([1, 2, 3]), [1, 2, 3]) + self.assertEqual(special_functions.stack(1, strict=False), 1) + self.assertListEqual( + special_functions.stack([1, 2, 3], strict=False), [1, 2, 3]) # TODO(mdan): This should probably forward to tf.stack. self.assertTrue( isinstance( special_functions.stack( [constant_op.constant(1), - constant_op.constant(2)]), list)) + constant_op.constant(2)], strict=False), list)) + + with self.assertRaises(ValueError): + special_functions.stack([1, 2, 3]) t = constant_op.constant([1.0, 2.0]) l = list_ops.tensor_list_from_tensor( diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 0c6ab65505ee03e19588adae73d3134399a34b65..332d5dab19e7ade1531b564fbdef2fa0dc2d09d5 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -28,7 +28,15 @@ py_library( visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:list_ops", "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_util", + "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", ], ) diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 671c9ccc13eaa887522cfc248a6d56d7ab9719ca..988df70157170ed0a9ece33976e871e6f7693bbc 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -51,7 +51,7 @@ def for_stmt(iter_, extra_test, body, init_state): Args: iter_: The entity being iterated over. extra_test: Callable with the state as arguments, and boolean return type. - An additionnal loop condition. + An additional loop condition. body: Callable with the iterate and the state as arguments, and state as return type. The actual loop body. init_state: Tuple containing the initial state. diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index 989b821e53a5cefbe39095e669f9a9e0bec65b8a..8f09689fe9b33bec03dc8b5370633c3a953fa322 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -23,7 +23,6 @@ py_library( "anno.py", "ast_util.py", "compiler.py", - "context.py", "inspect_utils.py", "parser.py", "pretty_printer.py", @@ -38,6 +37,8 @@ py_library( "@gast_archive//:gast", "@six_archive//:six", "@termcolor_archive//:termcolor", + # TODO(mdan): Remove this dependency. + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py index cc4a7edf02ed7556c9a552d8730e4c7875038c83..ae861627fd65cca057e7bf1af41424e605d4b7a1 100644 --- a/tensorflow/contrib/autograph/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -46,8 +46,15 @@ class Basic(NoValue): '`name_map` allows renaming symbols.') -def getanno(node, key, field_name='___pyct_anno'): - return getattr(node, field_name)[key] +FAIL = object() + + +def getanno(node, key, default=FAIL, field_name='___pyct_anno'): + if (default is FAIL or + (hasattr(node, field_name) and (key in getattr(node, field_name)))): + return getattr(node, field_name)[key] + else: + return default def hasanno(node, key, field_name='___pyct_anno'): @@ -73,5 +80,9 @@ def delanno(node, key, field_name='___pyct_anno'): def copyanno(from_node, to_node, key, field_name='___pyct_anno'): - if hasanno(from_node, key, field_name): - setanno(to_node, key, getanno(from_node, key, field_name), field_name) + if hasanno(from_node, key, field_name=field_name): + setanno( + to_node, + key, + getanno(from_node, key, field_name=field_name), + field_name=field_name) diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py index 1d4d9d119e0c45c4bf9dd4e5b8156766489a2e4d..f2c0c8cf05ca4b3671eb653ce56f6da61de54aee 100644 --- a/tensorflow/contrib/autograph/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -38,12 +38,14 @@ class AnnoTest(test.TestCase): anno.setanno(node, 'foo', 3) self.assertTrue(anno.hasanno(node, 'foo')) - self.assertEqual(3, anno.getanno(node, 'foo')) + self.assertEqual(anno.getanno(node, 'foo'), 3) + self.assertEqual(anno.getanno(node, 'bar', default=7), 7) anno.delanno(node, 'foo') self.assertFalse(anno.hasanno(node, 'foo')) with self.assertRaises(AttributeError): anno.getanno(node, 'foo') + self.assertIsNone(anno.getanno(node, 'foo', default=None)) def test_copyanno(self): node_1 = ast.Name() diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ca1441cf6f8bb034c95b37fcdd9e8158d1db2e39 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD @@ -0,0 +1,38 @@ +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "common_transformers", + srcs = [ + "anf.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/autograph/pyct", + "@gast_archive//:gast", + ], +) + +py_test( + name = "anf_test", + srcs = ["anf_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":common_transformers", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py new file mode 100644 index 0000000000000000000000000000000000000000..cc039986c219db1febfe610a5078e26eeb2d5a83 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Conversion to A-normal form.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import transformer + + +class DummyGensym(object): + """A dumb gensym that suffixes a stem by sequential numbers from 1000.""" + + def __init__(self, entity_info): + del entity_info + # A proper implementation needs to account for: + # * entity_info.namespace + # * all the symbols defined in the AST + # * the symbols generated so far + self._idx = 0 + + def new_name(self, stem): + self._idx += 1 + return stem + '_' + str(1000 + self._idx) + + +class AnfTransformer(transformer.Base): + """Performs the actual conversion.""" + + # TODO(mdan): Link to a reference. + # TODO(mdan): Implement. + + def __init__(self, entity_info): + """Creates a transformer. + + Args: + entity_info: transformer.EntityInfo + """ + super(AnfTransformer, self).__init__(entity_info) + self._gensym = DummyGensym(entity_info) + + +def transform(node, entity_info): + return AnfTransformer(entity_info).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py new file mode 100644 index 0000000000000000000000000000000000000000..81983a5ecb7b8c6216285409f854e27b7154a08b --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for anf module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.pyct import compiler +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.common_transformers import anf +from tensorflow.python.platform import test + + +class AnfTransformerTest(test.TestCase): + + def _simple_source_info(self): + return transformer.EntityInfo( + source_code=None, + source_file=None, + namespace=None, + arg_values=None, + arg_types=None, + owner_type=None) + + def test_basic(self): + + def test_function(): + a = 0 + return a + + node, _ = parser.parse_entity(test_function) + node = anf.transform(node, self._simple_source_info()) + result, _ = compiler.ast_to_object(node) + + self.assertEqual(test_function(), result.test_function()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/context.py b/tensorflow/contrib/autograph/pyct/context.py deleted file mode 100644 index b34015cfd2888f0dbeb6492b9e7335d561bf4763..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/autograph/pyct/context.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Conversion context containers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -class EntityContext(object): - """Contains information about an entity, like source code. - - In general, objects of this class should be considered immutable. - - Attributes: - namer: Namer that matches the contract of all converters. - source_code: The entity's source code. - source_file: The entity's source file. - namespace: Dict[str->*], containing symbols visible to the entity - (excluding parameters). - arg_values: Dict[str->*], containing parameter values, if known. - arg_types: Dict[str->*], containing parameter types, if known. - owner_type: The surrounding class type of the function, if present. - """ - - # TODO(mdan): Remove the default and update tests. - def __init__(self, namer, source_code, source_file, namespace, arg_values, - arg_types, owner_type, recursive, type_annotation_func=None): - self.namer = namer - self.source_code = source_code - self.source_file = source_file - self.namespace = namespace - self.arg_values = {} if arg_values is None else arg_values - self.arg_types = {} if arg_types is None else arg_types - self.owner_type = owner_type - self.recursive = recursive - self.type_annotation_func = type_annotation_func diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index 8064a967cd389e88d3febbeb21cac87b0fef9e18..bcf2dacec2062704805f1d72ec27a243159d13c1 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -27,6 +27,7 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/autograph/pyct", + "//tensorflow/contrib/autograph/utils", "@gast_archive//:gast", ], ) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index fdbd349af9d3325af114a7206d89617134278f14..bc22be0a270bbc9c361aea6d6d9c255ea51796e8 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -21,9 +21,9 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.qual_names import QN from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -112,18 +112,16 @@ class ActivityAnalyzerTest(test.TestCase): def _parse_and_analyze(self, test_fn): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace={}, arg_values=None, arg_types=None, - owner_type=None, - recursive=True) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - return node, ctx + node = activity.resolve(node, entity_info) + return node, entity_info def test_local_markers(self): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py index ad97fdfa8e78d1fd4c38724612d83519c6609cce..39eca6e44441cc28e565d383759cc796d57d6438 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -276,9 +276,9 @@ class Forward(object): taken). """ - def __init__(self, label, context, transfer_fn=operator.or_): + def __init__(self, label, source_info, transfer_fn=operator.or_): self.transfer_fn = transfer_fn - self.context = context + self.source_info = source_info self.out_label = label + '_out' self.in_label = label + '_in' self.gen_label = label + '_gen' @@ -286,7 +286,7 @@ class Forward(object): # TODO(alexbw): see if we can simplify by visiting breadth-first def visit(self, node): - """Depth-first walking the CFG, applying dataflow information propagtion.""" + """Depth-first walking the CFG, applying dataflow information propagation.""" # node.value is None only for the exit CfgNode. if not node.value: return @@ -399,18 +399,18 @@ class Liveness(Backward): later in the program. """ - def __init__(self, context): - super(Liveness, self).__init__('live', context) + def __init__(self, source_info): + super(Liveness, self).__init__('live', source_info) def get_gen_kill(self, node, _): # A variable's parents are live if it is live # e.g. x is live if x.y is live. This means gen needs to return # all parents of a variable (if it's an Attribute or Subscript). # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x) - gen = activity.get_read(node.value, self.context) + gen = activity.get_read(node.value, self.source_info) gen = functools.reduce(lambda left, right: left | right.support_set, gen, gen) - kill = activity.get_updated(node.value, self.context) + kill = activity.get_updated(node.value, self.source_info) return gen, kill @@ -420,11 +420,11 @@ class ReachingDefinitions(Forward): Each statement is annotated with a set of (variable, definition) pairs. """ - def __init__(self, context): - super(ReachingDefinitions, self).__init__('definitions', context) + def __init__(self, source_info): + super(ReachingDefinitions, self).__init__('definitions', source_info) def get_gen_kill(self, node, incoming): - definitions = activity.get_updated(node.value, self.context) + definitions = activity.get_updated(node.value, self.source_info) gen = frozenset((id_, node.value) for id_ in definitions) kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) return gen, kill @@ -437,9 +437,10 @@ class Defined(Forward): be defined at that point. """ - def __init__(self, context): - super(Defined, self).__init__('defined', context, transfer_fn=operator.and_) + def __init__(self, source_info): + super(Defined, self).__init__( + 'defined', source_info, transfer_fn=operator.and_) def get_gen_kill(self, node, _): - gen = activity.get_updated(node.value, self.context) + gen = activity.get_updated(node.value, self.source_info) return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py index fc07fa3447b23c0595a5893329de8a2d7055ca15..428ebbedca85f9b94b4b1db0f3b36a334126196b 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py @@ -23,29 +23,26 @@ import functools import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import cfg from tensorflow.python.platform import test class CFGTest(test.TestCase): - def _parse_and_analyze(self, test_fn, namespace, arg_types=None): - arg_types = arg_types or {} + def _parse_and_analyze(self, test_fn): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, - namespace=namespace, + namespace={}, arg_values=None, - arg_types=arg_types, - owner_type=None, - recursive=True) + arg_types=None, + owner_type=None) node = qual_names.resolve(node) - return node, ctx + return node, entity_info def _check_anno_matches(self, node, anno_name, var_names): if isinstance(var_names, str): @@ -73,7 +70,7 @@ class CFGTest(test.TestCase): x = x return x - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) body = node.body[0].body # Only the argument reaches the expression @@ -106,7 +103,7 @@ class CFGTest(test.TestCase): y = 2 # pylint: disable=unused-variable return x - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Defined(ctx)) body = node.body[0].body # only x is for sure defined at the end @@ -116,7 +113,7 @@ class CFGTest(test.TestCase): self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) def _get_live_annotated_fnbody(self, f): - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Liveness(ctx)) body = node.body[0].body return body @@ -226,7 +223,7 @@ class CFGTest(test.TestCase): return g(x) - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Defined(ctx)) body = node.body[0].body @@ -253,7 +250,7 @@ class CFGTest(test.TestCase): return g() # y is not defined here - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.Defined(ctx)) body = node.body[0].body self.assertEqual( @@ -282,7 +279,7 @@ class CFGTest(test.TestCase): return x, y for f in (for_orelse, while_orelse): - node, ctx = self._parse_and_analyze(f, {}) + node, ctx = self._parse_and_analyze(f) cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) body = node.body[0].body return_node = body[-1] diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py index 53ae15459097baff918432a493edd7360ebf209d..9ccb98f79adbe5410a7554548ee75ab95345962d 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py @@ -39,7 +39,7 @@ class LiveValueResolver(transformer.Base): def visit_ClassDef(self, node): self.generic_visit(node) - anno.setanno(node, 'live_val', self.context.namespace[node.name]) + anno.setanno(node, 'live_val', self.entity_info.namespace[node.name]) return node def visit_Name(self, node): @@ -55,8 +55,8 @@ class LiveValueResolver(transformer.Base): if not symbol_is_local and not symbol_is_param: if node.id in self.literals: anno.setanno(node, 'live_val', self.literals[node.id]) - elif node.id in self.context.namespace: - obj = self.context.namespace[node.id] + elif node.id in self.entity_info.namespace: + obj = self.entity_info.namespace[node.id] anno.setanno(node, 'live_val', obj) if hasattr(obj, '__name__'): anno.setanno(node, 'fqn', (obj.__name__,)) @@ -80,8 +80,8 @@ class LiveValueResolver(transformer.Base): # TODO(mdan): Use type annotations as fallback. if not symbol_is_modified: - if node.id in self.context.arg_values: - obj = self.context.arg_values[node.id] + if node.id in self.entity_info.arg_values: + obj = self.entity_info.arg_values[node.id] anno.setanno(node, 'live_val', obj) anno.setanno(node, 'fqn', (obj.__class__.__name__,)) return node diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py index 69e428bde109ed43c3cdda1a94970a832dc47852..38af79277779f77ffe31c2f6e26ae88f3e1a7ae9 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py @@ -21,9 +21,9 @@ from __future__ import print_function import six from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info @@ -39,22 +39,19 @@ class LiveValuesResolverTest(test.TestCase): literals=None, arg_types=None): literals = literals or {} - arg_types = arg_types or {} node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=None, - recursive=True) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, literals) - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, literals) + node = activity.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, literals) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, literals) return node def test_literals(self): diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index d6555dc7e0b3d49b3befa7326b28387509c83006..a229c288a83e516fc02f3af8df2046c5365e569c 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -17,8 +17,8 @@ This analyzer uses known live values to further infer object types. This may include for instance constructed objects and object member functions. -In addition, the analyzer will also process annotations for TF (staged) type -annotations. +In addition, the analyzer also handles user annotations made in the code (for +example, the autograph.set_element_type function). Requires annotations generated by LiveValuesResolver. """ @@ -43,7 +43,9 @@ from __future__ import print_function import gast +from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -51,6 +53,7 @@ from tensorflow.python.util import tf_inspect # TODO(mdan): Remove the duplication between this and activity.py. # In particular, the symbol definitions we track here could as well be tracked # there because they follow the same rules for visibility. +# TODO(mdan): Use a CFG based Defined analysis instead. class Scope(object): """Tracks symbol value references. @@ -134,37 +137,40 @@ class TypeInfoResolver(transformer.Base): node.orelse = self._visit_block(node.orelse) return node - def _process_function_arg(self, arg_name): - str_name = str(arg_name) - type_holder = arg_name.ast() - self.scope.setval(arg_name, type_holder) - if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: + def _process_function_arg(self, arg_node): + qn = anno.getanno(arg_node, anno.Basic.QN) + arg_name = str(qn) + self.scope.setval(qn, arg_node) + if (len(self.enclosing_entities) == 1 and + arg_name in self.entity_info.arg_types): # Forge a node to hold the type information, so that method calls on # it can resolve the type. - type_string, type_obj = self.context.arg_types[str_name] - anno.setanno(type_holder, 'type', type_obj) - anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) + type_string, type_obj = self.entity_info.arg_types[arg_name] + anno.setanno(arg_node, 'type', type_obj) + anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.'))) def visit_arg(self, node): - self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) + self._process_function_arg(node.arg) return node def visit_Name(self, node): self.generic_visit(node) - qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Param): - self._process_function_arg(qn) - elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn): - # E.g. if we had - # a = b - # then for future references to `a` we should have definition = `b` - definition = self.scope.getval(qn) - if anno.hasanno(definition, 'type'): - anno.setanno(node, 'type', anno.getanno(definition, 'type')) - anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn')) - if anno.hasanno(definition, 'element_type'): - anno.setanno(node, 'element_type', - anno.getanno(definition, 'element_type')) + self._process_function_arg(node) + elif isinstance(node.ctx, gast.Load): + qn = anno.getanno(node, anno.Basic.QN) + if self.scope.hasval(qn): + # E.g. if we had + # a = b + # then for future references to `a` we should have definition = `b` + definition = self.scope.getval(qn) + anno.copyanno(definition, node, 'type') + anno.copyanno(definition, node, 'type_fqn') + anno.setanno(node, 'definition', definition) + + # TODO(mdan): Remove this when the directives module is in. + anno.copyanno(definition, node, 'element_type') + anno.copyanno(definition, node, 'element_shape') return node def _process_variable_assignment(self, target, value): @@ -204,30 +210,27 @@ class TypeInfoResolver(transformer.Base): node.targets, node.value, self._process_variable_assignment) return node + # TODO(mdan): Remove as soon as the new directives module is ready. def visit_Call(self, node): if anno.hasanno(node.func, 'live_val'): # Symbols targeted by the "set_type" marker function are assigned the data # type that it specified. - if (anno.getanno(node.func, 'live_val') is - self.context.type_annotation_func): + if anno.getanno(node.func, 'live_val') is utils.set_element_type: - if len(node.args) != 2: - raise ValueError('"%s" must have exactly two parameters' + if len(node.args) < 2 or len(node.args) > 3: + raise ValueError('"%s" must have either two or three parameters' % self.context.type_annotation_func) - target_arg, type_arg = node.args - if not anno.hasanno(target_arg, anno.Basic.QN): - raise ValueError('the first argument of "%s" must by a symbol' - % self.context.type_annotation_func) - if isinstance(type_arg, gast.Str): - element_type = type_arg.s - elif isinstance(type_arg, gast.Num): - element_type = type_arg.n + if len(node.args) == 2: + target_arg, type_arg = node.args + shape_arg = parser.parse_expression('None') else: - if not anno.hasanno(type_arg, 'live_val'): - raise ValueError( - 'the second argument of "%s" must be statically resolvable' % - self.context.type_annotation_func) - element_type = anno.getanno(type_arg, 'live_val') + target_arg, type_arg, shape_arg = node.args + if not anno.hasanno(target_arg, anno.Basic.QN): + raise ValueError('the first argument of "%s" must by a symbol' % + utils.set_element_type) + # TODO(mdan): This is vulnerable to symbol renaming. + element_type = type_arg + element_shape = shape_arg target_symbol = anno.getanno(target_arg, anno.Basic.QN) # Find the definition of this symbol and annotate it with the given @@ -235,7 +238,9 @@ class TypeInfoResolver(transformer.Base): # to receive the same type annotation. definition = self.scope.getval(target_symbol) anno.setanno(node, 'element_type', element_type) + anno.setanno(node, 'element_shape', element_shape) anno.setanno(definition, 'element_type', element_type) + anno.setanno(definition, 'element_shape', element_shape) # TODO(mdan): Should we update references between definition and here? return self.generic_visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 95cbf5ca79a5045f5e050b735390dcfb668b5bb2..32b1148ab21809514bc09a31e26f0219017bd088 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -18,11 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis import activity from tensorflow.contrib.autograph.pyct.static_analysis import live_values from tensorflow.contrib.autograph.pyct.static_analysis import type_info @@ -62,21 +61,18 @@ class TypeInfoResolverTest(test.TestCase): namespace, arg_types=None): node, source = parser.parse_entity(test_fn) - ctx = context.EntityContext( - namer=None, + entity_info = transformer.EntityInfo( source_code=source, source_file=None, namespace=namespace, arg_values=None, arg_types=arg_types, - owner_type=None, - recursive=True, - type_annotation_func=utils.set_element_type) + owner_type=None) node = qual_names.resolve(node) - node = activity.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) - node = type_info.resolve(node, ctx) - node = live_values.resolve(node, ctx, {}) + node = activity.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) + node = type_info.resolve(node, entity_info) + node = live_values.resolve(node, entity_info, {}) return node def test_constructor_detection(self): @@ -147,7 +143,7 @@ class TypeInfoResolverTest(test.TestCase): opt.minimize(0) node = self._parse_and_analyze( - test_fn, {'training': training}, + test_fn, {}, arg_types={ 'opt': (training.GradientDescentOptimizer.__name__, training.GradientDescentOptimizer) @@ -180,35 +176,6 @@ class TypeInfoResolverTest(test.TestCase): method_call = node.body[0].body[1].value.func self.assertFalse(anno.hasanno(method_call, 'live_val')) - def test_type_annotation(self): - - class Foo(object): - pass - - def test_fn(): - f = [] - f = utils.set_element_type(f, Foo) - return f - - node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) - f_def = node.body[0].body[0].value - self.assertEqual(anno.getanno(f_def, 'element_type'), Foo) - f_ref = node.body[0].body[1].value - self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) - - def test_type_annotation_args(self): - - class Foo(object): - pass - - def test_fn(f): - utils.set_element_type(f, Foo) - return f - - node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) - f_ref = node.body[0].body[1].value - self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) - def test_nested_unpacking(self): class Foo(object): @@ -223,32 +190,13 @@ class TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar}) a, b, c = node.body[0].body[1].value.elts - self.assertEquals(Foo, anno.getanno(a, 'type')) - self.assertEquals(Bar, anno.getanno(b, 'type')) - self.assertEquals(Foo, anno.getanno(c, 'type')) + self.assertEquals(anno.getanno(a, 'type'), Foo) + self.assertEquals(anno.getanno(b, 'type'), Bar) + self.assertEquals(anno.getanno(c, 'type'), Foo) self.assertFalse(anno.hasanno(a, 'live_val')) self.assertFalse(anno.hasanno(b, 'live_val')) self.assertFalse(anno.hasanno(c, 'live_val')) - def test_inner_scope(self): - - def test_fn(): - a = [] - utils.set_element_type(a, 1) - for _ in a: - b = [] - utils.set_element_type(b, 2) - return a, b - - node = self._parse_and_analyze(test_fn, {'utils': utils}) - a, b = node.body[0].body[2].body[2].value.elts - self.assertEquals(1, anno.getanno(a, 'element_type')) - self.assertEquals(2, anno.getanno(b, 'element_type')) - self.assertFalse(anno.hasanno(a, 'type')) - self.assertFalse(anno.hasanno(b, 'type')) - self.assertFalse(anno.hasanno(a, 'live_val')) - self.assertFalse(anno.hasanno(b, 'live_val')) - if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py index baf7923fff7c786c1abd05e11fa6ffdb8c8f0912..9c479ebc2fa83d27dc363ae306daedb556734a1f 100644 --- a/tensorflow/contrib/autograph/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -239,8 +239,13 @@ def replace_as_expression(template, **replacements): raise ValueError( 'single expression expected; for more general templates use replace') node = replacement[0] - if not isinstance(node, gast.Expr): - raise ValueError( - 'the template is expected to generate an expression node; instead ' - 'found %s' % node) - return node.value + node = qual_names.resolve(node) + + if isinstance(node, gast.Expr): + return node.value + elif isinstance(node, gast.Name): + return node + + raise ValueError( + 'the template is expected to generate an expression or a name node;' + ' instead found %s' % node) diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 60bca8b38dcf62b4e997379d075cfc45511a894f..76558118308c31a2c1a770cad814e96abd6a6063 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -32,15 +32,40 @@ class AutographParseError(SyntaxError): pass -def try_ast_to_source(node): - try: - return compiler.ast_to_source(node) - except AssertionError: - return '' +# TODO(mdan): Use namedtuple. +class EntityInfo(object): + """Contains information about a Python entity. Immutable. + + Examples of entities include functions and classes. + + Attributes: + source_code: The entity's source code. + source_file: The entity's source file. + namespace: Dict[str, ], containing symbols visible to the entity + (excluding parameters). + arg_values: dict[str->*], containing parameter values, if known. + arg_types: dict[str->*], containing parameter types, if known. + owner_type: The surrounding class type of the function, if present. + """ + + # TODO(mdan): Remove the default and update tests. + def __init__(self, source_code, source_file, namespace, arg_values, arg_types, + owner_type): + self.source_code = source_code + self.source_file = source_file + self.namespace = namespace + self.arg_values = {} if arg_values is None else arg_values + self.arg_types = {} if arg_types is None else arg_types + self.owner_type = owner_type class Base(gast.NodeTransformer): - """Base class for specialized transformers. + """Base class for general-purpose code transformers transformers. + + This is an extension of ast.NodeTransformer that provides a few additional + functions, like state tracking within the scope of arbitrary node, helpers + for processing code blocks, debugging, mapping of transformed code to + original code, and others. Scope-local state tracking: to keep state across nodes, at the level of (possibly nested) scopes, use enter/exit_local_scope and set/get_local. @@ -48,15 +73,17 @@ class Base(gast.NodeTransformer): when they are not properly paired. """ - def __init__(self, context): + # TODO(mdan): Document all extra features. + + def __init__(self, entity_info): """Initialize the transformer. Subclasses should call this. Args: - context: An EntityContext. + entity_info: An EntityInfo object. """ self._lineno = 0 self._col_offset = 0 - self.context = context + self.entity_info = entity_info self._enclosing_entities = [] # A stack that allows keeping mutable, scope-local state where scopes may be @@ -191,7 +218,7 @@ class Base(gast.NodeTransformer): # TODO(mdan): Once we have error tracing, we may be able to just go to SSA. def apply_to_single_assignments(self, targets, values, apply_fn): - """Applies a fuction to each individual assignment. + """Applies a function to each individual assignment. This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. It tries to break down the unpacking if possible. In effect, it has the same @@ -219,7 +246,7 @@ class Base(gast.NodeTransformer): targets field of an ast.Assign node. values: an AST node. apply_fn: a function of a single argument, which will be called with the - respective nodes of each single assignment. The signaure is + respective nodes of each single assignment. The signature is apply_fn(target, value), no return value. """ if not isinstance(targets, (list, tuple)): @@ -237,9 +264,15 @@ class Base(gast.NodeTransformer): # TODO(mdan): Look into allowing to rewrite the AST here. apply_fn(target, values) + def _get_source(self, node): + try: + return compiler.ast_to_source(node) + except AssertionError: + return '' + def visit(self, node): - source_code = self.context.source_code - source_file = self.context.source_file + source_code = self.entity_info.source_code + source_file = self.entity_info.source_file did_enter_function = False local_scope_size_at_entry = len(self._local_scope_state) @@ -275,7 +308,7 @@ class Base(gast.NodeTransformer): except (ValueError, AttributeError, KeyError, NotImplementedError) as e: msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( - e.__class__.__name__, str(e), try_ast_to_source(node), + e.__class__.__name__, str(e), self._get_source(node), pretty_printer.fmt(node, color=False)) if source_code: line = source_code.splitlines()[self._lineno - 1] diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index f110e79605945e908e8a49112cf758ec29fa1b11..baf04653ae862b0159fb50a1c67fa675ceb74b9a 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -21,7 +21,6 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.platform import test @@ -29,16 +28,14 @@ from tensorflow.python.platform import test class TransformerTest(test.TestCase): - def _context_for_testing(self): - return context.EntityContext( - namer=None, + def _simple_source_info(self): + return transformer.EntityInfo( source_code=None, source_file=None, namespace=None, arg_values=None, arg_types=None, - owner_type=None, - recursive=False) + owner_type=None) def test_entity_scope_tracking(self): @@ -55,7 +52,7 @@ class TransformerTest(test.TestCase): anno.setanno(node, 'enclosing_entities', self.enclosing_entities) return self.generic_visit(node) - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) def test_function(): a = 0 @@ -118,7 +115,7 @@ class TransformerTest(test.TestCase): def visit_For(self, node): return self._annotate_result(node) - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) def test_function(a): """Docstring.""" @@ -157,7 +154,7 @@ class TransformerTest(test.TestCase): self.exit_local_scope() return node - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) def no_exit(a): if a > 0: @@ -196,7 +193,7 @@ class TransformerTest(test.TestCase): z = y return z - tr = TestTransformer(self._context_for_testing()) + tr = TestTransformer(self._simple_source_info()) node, _ = parser.parse_entity(test_function) node = tr.visit(node) diff --git a/tensorflow/contrib/batching/__init__.py b/tensorflow/contrib/batching/__init__.py index 44fa5f42a73bfb1bf008f6f4eafd14913c88dcfa..1e503a097a7b72d9244b0a1cf57747c4b4122c81 100644 --- a/tensorflow/contrib/batching/__init__.py +++ b/tensorflow/contrib/batching/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Ops and modules related to batch. +@@batch_function_v1 @@batch_function """ from __future__ import absolute_import diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 921d6917a4e478c3e60771fdc3ae99febc33d2e3..47b80bdf4ad88ebce3603a14ea2aa3cbe5bd345f 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import @@ -83,6 +84,70 @@ def batch_function(num_batch_threads, SparseTensor is not supported. The return value of the decorated function must be a Tensor or a list/tuple of Tensors. + Args: + num_batch_threads: Number of scheduling threads for processing batches + of work. Determines the number of batches processed in parallel. + max_batch_size: Batch sizes will never be bigger than this. + batch_timeout_micros: Maximum number of microseconds to wait before + outputting an incomplete batch. + allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, + does nothing. Otherwise, supplies a list of batch sizes, causing the op + to pad batches up to one of those sizes. The entries must increase + monotonically, and the final entry must equal max_batch_size. + grad_timeout_micros: The timeout to use for the gradient. See the + documentation of the unbatch op for more details. Defaults to 60s. + unbatch_timeout_micros: The timeout to use for unbatching. See the + documentation of the unbatch op for more details. Defaults to 60s. + max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. + + Returns: + The decorated function will return the unbatched computation output Tensors. + """ + + def decorator(fn): # pylint: disable=missing-docstring + + def decorated(*args): # pylint: disable=missing-docstring + types = [arg.dtype for arg in args] + + @function.Defun(*types) + def computation(*computation_args): + return fn(*computation_args) + + with ops.name_scope("batch") as name: + for a in args: + if not isinstance(a, ops.Tensor): + raise ValueError("All arguments to functions decorated with " + "`batch_function` are supposed to be Tensors; " + "found %s" % repr(a)) + return gen_batch_ops.batch_function( + num_batch_threads=num_batch_threads, + max_batch_size=max_batch_size, + batch_timeout_micros=batch_timeout_micros, + allowed_batch_sizes=allowed_batch_sizes, + max_enqueued_batches=max_enqueued_batches, + shared_name=name, + f=computation, + in_tensors=list(args), + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + return decorated + + return decorator + + +def batch_function_v1(num_batch_threads, + max_batch_size, + batch_timeout_micros, + allowed_batch_sizes=None, + grad_timeout_micros=60 * 1000 * 1000, + unbatch_timeout_micros=60 * 1000 * 1000, + max_enqueued_batches=10): + """Batches the computation done by the decorated function. + + This is the older version of batch_function(). Please use the former instead + of this. + Args: num_batch_threads: Number of scheduling threads for processing batches of work. Determines the number of batches processed in parallel. diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index ea8339334f9b5e58a35dc9edf314a220e4c9868c..78468145469df216344bc00f116add250dc51dd3 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -188,12 +188,62 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBasicUnbatchV1Decorated(self): + """Tests that the batch_function_v1 decorator works.""" + with self.test_session() as sess: + @batch_ops.batch_function_v1(1, 10, 100000) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testBasicUnbatchDecorated(self): """Tests that the batch_function decorator works.""" with self.test_session() as sess: + # TODO(apassos): Removing this line causes test flakiness! Ideally should + # be investigated. + default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable + @batch_ops.batch_function(1, 10, 100000) def computation(in_t): return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchDecoratedWithCapturedInput(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return in_t + captured_inp0 - captured_inp1 + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) result = computation(inp) thread_results = [] diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py index 5770bcdd706723394bb06196d24aeb32b8b8491a..68fa415eeaf1d1ae7c2ecf1be1c300eddbfa4e69 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monte Carlo integration and helpers. - -See the @{$python/contrib.bayesflow.monte_carlo} guide. -""" +"""Monte Carlo integration and helpers.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 758754feac31f1d2cf10e69d7a9a6d288931c900..911d87fa10570382ee5f03edfc1bfd1d116c8360 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -232,7 +232,13 @@ def _dnn_tree_combined_model_fn(features, return update_op if predict_with_tree_only: - tree_train_logits = tree_logits + if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.PREDICT: + tree_train_logits = tree_logits + else: + tree_train_logits = control_flow_ops.cond( + global_step > dnn_steps_to_train, + lambda: tree_logits, + lambda: dnn_logits) else: tree_train_logits = dnn_logits + tree_logits diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 89d0d611d2905492cec09e033b8cbc238ec7fac6..9c36c302210185bc390751a0229a61f2f8cd91b8 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -41,7 +41,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -66,6 +67,16 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree + Raises: ValueError: If learner_config is not valid. """ @@ -74,7 +85,9 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): # supports second order derivative. def loss_fn(labels, logits, weights=None): result = losses.per_example_maxent_loss( - labels=labels, logits=logits, weights=weights, + labels=labels, + logits=logits, + weights=weights, num_classes=n_classes) return math_ops.reduce_mean(result[0]) else: @@ -102,6 +115,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'center_bias': center_bias, 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, + 'output_leaf_index': output_leaf_index, }, model_dir=model_dir, config=config, @@ -124,7 +138,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -151,6 +166,13 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree """ head = head_lib.regression_head( label_name=label_name, @@ -173,6 +195,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, 'use_core_libs': use_core_libs, + 'output_leaf_index': False, }, model_dir=model_dir, config=config, @@ -197,7 +220,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -220,6 +244,13 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -233,6 +264,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, 'use_core_libs': use_core_libs, + 'output_leaf_index': False, }, model_dir=model_dir, config=config, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 0d58317bd59331cfcde0e12aeb3a3a03fc45d89b..75ef1b050028b6462b255827c06e836e5c481844 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -68,6 +68,28 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): classifier.evaluate(input_fn=_eval_input_fn, steps=1) classifier.export(self._export_dir_base) + def testThatLeafIndexIsInPredictions(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("leaf_index" in prediction_dict) + self.assertTrue("logits" in prediction_dict) + def testFitAndEvaluateDontThrowExceptionWithCoreForEstimator(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 15ab6d814522ab1dee58dcd71246354fc4d8a483..1ee891198939e53fc5913104b2c2e65dc977823f 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -63,6 +63,8 @@ def model_builder(features, labels, mode, params, config): num_trees = params["num_trees"] use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] + output_leaf_index = params["output_leaf_index"] + if features is None: raise ValueError("At least one feature must be specified.") @@ -96,7 +98,8 @@ def model_builder(features, labels, mode, params, config): feature_columns=feature_columns, logits_dimension=head.logits_dimension, features=training_features, - use_core_columns=use_core_libs) + use_core_columns=use_core_libs, + output_leaf_index=output_leaf_index) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] @@ -127,6 +130,9 @@ def model_builder(features, labels, mode, params, config): labels=labels, train_op_fn=_train_op_fn, logits=logits) + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] if num_trees: if center_bias: num_trees += 1 diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index b3fe38614e05801b223f0c96f7a70ce7e432a70b..9493c1a1394040db3b744f1b382b20bd5bd1988d 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -59,6 +59,7 @@ const char* kApplyDropoutAttributeName = "apply_dropout"; const char* kApplyAveragingAttributeName = "apply_averaging"; const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights"; const char* kPredictionsTensorName = "predictions"; +const char* kLeafIndexTensorName = "leaf_index"; void CalculateTreesToInclude( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, @@ -170,15 +171,22 @@ class GradientTreesPredictionOp : public OpKernel { core::ScopedUnref unref_me(ensemble_resource); if (use_locking_) { tf_shared_lock l(*ensemble_resource->get_mutex()); - DoCompute(context, ensemble_resource); + DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/false); } else { - DoCompute(context, ensemble_resource); + DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/false); } } - private: - void DoCompute(OpKernelContext* context, - DecisionTreeEnsembleResource* ensemble_resource) { + protected: + // return_output_leaf_index is a boolean variable indicating whether to output + // leaf index in prediction. Though this class invokes only with this param + // value as false, the subclass GradientTreesPredictionVerboseOp will invoke + // with the true value. + virtual void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource, + const bool return_output_leaf_index) { // Read dense float features list; OpInputList dense_float_features_list; OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures( @@ -267,6 +275,14 @@ class GradientTreesPredictionOp : public OpKernel { &output_predictions_t)); auto output_predictions = output_predictions_t->matrix(); + // Allocate output leaf index matrix. + Tensor* output_leaf_index_t = nullptr; + if (return_output_leaf_index) { + OP_REQUIRES_OK(context, context->allocate_output( + kLeafIndexTensorName, + {batch_size, ensemble_resource->num_trees()}, + &output_leaf_index_t)); + } // Run predictor. thread::ThreadPool* const worker_threads = context->device()->tensorflow_cpu_worker_threads()->workers; @@ -288,11 +304,13 @@ class GradientTreesPredictionOp : public OpKernel { i, weight * (num_ensembles - i + start_averaging) / num_ensembles); } MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features, - worker_threads, output_predictions); + worker_threads, output_predictions, + output_leaf_index_t); } else { MultipleAdditiveTrees::Predict( ensemble_resource->decision_tree_ensemble(), trees_to_include, - batch_features, worker_threads, output_predictions); + batch_features, worker_threads, output_predictions, + output_leaf_index_t); } // Output dropped trees and original weights. @@ -302,7 +320,6 @@ class GradientTreesPredictionOp : public OpKernel { {2, static_cast(dropped_trees.size())}, &output_dropout_info_t)); auto output_dropout_info = output_dropout_info_t->matrix(); - for (int32 i = 0; i < dropped_trees.size(); ++i) { output_dropout_info(0, i) = dropped_trees[i]; output_dropout_info(1, i) = original_weights[i]; @@ -326,6 +343,27 @@ class GradientTreesPredictionOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("GradientTreesPrediction").Device(DEVICE_CPU), GradientTreesPredictionOp); +// GradientTreesPredictionVerboseOp is derived from GradientTreesPredictionOp +// and have an additional output of tensor of rank 2 containing leaf ids for +// each tree where an instance ended up with. +class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp { + public: + explicit GradientTreesPredictionVerboseOp(OpKernelConstruction* const context) + : GradientTreesPredictionOp(context) {} + + protected: + void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource, + bool return_output_leaf_index) override { + GradientTreesPredictionOp::DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/true); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("GradientTreesPredictionVerbose").Device(DEVICE_CPU), + GradientTreesPredictionVerboseOp); + class GradientTreesPartitionExamplesOp : public OpKernel { public: explicit GradientTreesPartitionExamplesOp(OpKernelConstruction* const context) diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc index 43b00d4c6dc2e0066810012292874314215c41be..c9223afeab233497bce9f680bd44bd10ccfc6491 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc @@ -26,7 +26,8 @@ void MultipleAdditiveTrees::Predict( const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, tensorflow::thread::ThreadPool* const worker_threads, - tensorflow::TTypes::Matrix output_predictions) { + tensorflow::TTypes::Matrix output_predictions, + Tensor* const output_leaf_index) { // Zero out predictions as the model is additive. output_predictions.setZero(); @@ -38,8 +39,13 @@ void MultipleAdditiveTrees::Predict( // Lambda for doing a block of work. auto update_predictions = [&config, &features, &trees_to_include, - &output_predictions](int64 start, int64 end) { + &output_predictions, + &output_leaf_index](int64 start, int64 end) { auto examples_iterable = features.examples_iterable(start, end); + Tensor dummy_tensor(DT_INT32, TensorShape({1, 1})); + tensorflow::TTypes::Matrix output_leaf_index_mat = + output_leaf_index != nullptr ? output_leaf_index->matrix() + : dummy_tensor.matrix(); for (const auto& example : examples_iterable) { for (const int32 tree_idx : trees_to_include) { const boosted_trees::trees::DecisionTreeConfig& tree = @@ -47,6 +53,10 @@ void MultipleAdditiveTrees::Predict( const float tree_weight = config.tree_weights(tree_idx); const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); + // Checks if output leaf tree index is required. + if (output_leaf_index != nullptr) { + output_leaf_index_mat(example.example_idx, tree_idx) = leaf_idx; + } const auto& leaf_node = tree.nodes(leaf_idx); QCHECK(leaf_node.has_leaf()) << "Invalid leaf node: " << leaf_node.DebugString(); diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h index cc3dc226cdbc88fc7010ada1e7f0e6c0a3913c5f..940531c4ba4bcac19fa980deb091e55b48e0693b 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h @@ -33,12 +33,17 @@ class MultipleAdditiveTrees { public: // Predict runs tree ensemble on the given batch and updates // output predictions accordingly, for the given list of trees. + // output_leaf_indices is a pointer to a 2 dimensional tensor. If it is not + // nullptr, this method fills output_leaf_indices with a per-tree leaf id + // where each of the instances from 'features' ended up in. Its shape is num + // examples X num of trees. static void Predict( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, tensorflow::thread::ThreadPool* const worker_threads, - tensorflow::TTypes::Matrix output_predictions); + tensorflow::TTypes::Matrix output_predictions, + Tensor* const output_leaf_index); }; } // namespace models diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc index 4ca18bedb1054ef64c6d4b25bbad04842bab1a6a..462a9ac86fe51d07cfb958d9be49bef84811a52e 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc @@ -62,7 +62,8 @@ TEST_F(MultipleAdditiveTreesTest, Empty) { tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", kNumThreadsSingleThreaded); MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, + /*output_leaf_index=*/nullptr); EXPECT_EQ(0, output_matrix(0, 0)); EXPECT_EQ(0, output_matrix(1, 0)); } @@ -99,17 +100,38 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + /*output_leaf_index=*/nullptr); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } + // Normal case with leaf node. + { + // Initialize output leaf index tensor, since leaf index is positive in this + // case, initialize with the value of -1. Since there are 2 examples and + // there are 2 trees, initialize leaf output index by 2 * 2. + Tensor output_leaf_index_tensor(DT_INT32, TensorShape({2, 2})); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix, + &output_leaf_index_tensor); + EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). + EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). + EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix()( + 0, 0)); // 1st leaf for the first example + EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix()( + 1, 0)); // 1st leaf for the second example + EXPECT_FLOAT_EQ(2, output_leaf_index_tensor.matrix()( + 0, 1)); // 2nd leaf for the first example + EXPECT_FLOAT_EQ(1, output_leaf_index_tensor.matrix()( + 1, 1)); // 2nd leaf for the second example + } // Weighted case { DecisionTreeEnsembleConfig weighted = tree_ensemble_config; weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, - output_matrix); + output_matrix, nullptr); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0)); // -0.4 (bias) + 0.9 (leaf 1). @@ -118,21 +140,21 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { // Drop first tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1). } // Drop second tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias). EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias). } // Drop all trees. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0)); } @@ -172,7 +194,8 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1) @@ -184,7 +207,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, - output_matrix); + output_matrix, nullptr); // bias EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0)); // bias + leaf 2 @@ -197,7 +220,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Dropout first tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2) @@ -206,7 +229,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Dropout second tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias) EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias) @@ -215,7 +238,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Drop both trees. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1)); EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0)); @@ -258,7 +281,8 @@ TEST_F(MultipleAdditiveTreesTest, DenseLeaves) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + nullptr); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2) EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2) diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc index d66f645f62aba84261337eb37d6e3204930f8f15..6491d58794332e9417951753532e018aafb652b1 100644 --- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc @@ -40,6 +40,24 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) { return Status::OK(); } +static Status ApplyGradientTreesPredictionVerboseShapeFn(InferenceContext* c) { + string learner_config_str; + c->GetAttr("learner_config", &learner_config_str).IgnoreError(); + LearnerConfig learner_config; + ParseProtoUnlimited(&learner_config, learner_config_str); + + bool reduce_dim; + c->GetAttr("reduce_dim", &reduce_dim).IgnoreError(); + // Sets the shape of the output as a matrix. + c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, + reduce_dim ? learner_config.num_classes() - 1 + : learner_config.num_classes())}); + c->set_output(1, {c->UnknownShape()}); + c->set_output(2, {c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)}); + return Status::OK(); +} + REGISTER_OP("GradientTreesPrediction") .Attr("learner_config: string") .Attr("num_dense_float_features: int >= 0") @@ -90,6 +108,58 @@ drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices and original weights of those trees during prediction. )doc"); +REGISTER_OP("GradientTreesPredictionVerbose") + .Attr("learner_config: string") + .Attr("num_dense_float_features: int >= 0") + .Attr("num_sparse_float_features: int >= 0") + .Attr("num_sparse_int_features: int >= 0") + .Attr("use_locking: bool = false") + .Attr("apply_dropout: bool") + .Attr("apply_averaging: bool") + .Attr("center_bias: bool") + .Attr("reduce_dim: bool") + .Input("tree_ensemble_handle: resource") + .Input("seed: int64") + .Input("dense_float_features: num_dense_float_features * float") + .Input("sparse_float_feature_indices: num_sparse_float_features * int64") + .Input("sparse_float_feature_values: num_sparse_float_features * float") + .Input("sparse_float_feature_shapes: num_sparse_float_features * int64") + .Input("sparse_int_feature_indices: num_sparse_int_features * int64") + .Input("sparse_int_feature_values: num_sparse_int_features * int64") + .Input("sparse_int_feature_shapes: num_sparse_int_features * int64") + .Output("predictions: float") + .Output("drop_out_tree_indices_weights: float") + .Output("leaf_index: int32") + .SetShapeFn(ApplyGradientTreesPredictionVerboseShapeFn) + .Doc(R"doc( +Runs multiple additive regression forests predictors on input instances +and computes the final prediction for each class, and outputs a matrix of +leaf ids per each tree in an ensemble. + +learner_config: Config for the learner of type LearnerConfig proto. Prediction +ops for now uses only LearningRateDropoutDrivenConfig config from the learner. +num_dense_float_features: Number of dense float features. +num_sparse_float_features: Number of sparse float features. +num_sparse_int_features: Number of sparse int features. +use_locking: Whether to use locking. +seed: random seed to be used for dropout. +reduce_dim: whether to reduce the dimension (legacy impl) or not. +apply_dropout: whether to apply dropout during prediction. +apply_averaging: whether averaging of tree ensembles should take place. If set +to true, will be based on AveragingConfig from learner_config. +tree_ensemble_handle: The handle to the tree ensemble. +dense_float_features: Rank 2 Tensors containing dense float feature values. +sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices. +sparse_float_feature_values: Rank 1 Tensors containing sparse float values. +sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes. +sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices. +sparse_int_feature_values: Rank 1 Tensors containing sparse int values. +sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes. +predictions: Rank 2 Tensor containing predictions per example per class. +drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices +leaf_index: tensor of rank 2 containing leaf ids for each tree where an instance ended up. +)doc"); + REGISTER_OP("GradientTreesPartitionExamples") .Attr("num_dense_float_features: int >= 0") .Attr("num_sparse_float_features: int >= 0") diff --git a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py index 58f0d36b0f78eeed6abcec1c4fa696f4ccffa615..7f6e55ae5888fc4ef50e34690d61c3ed303e971a 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py @@ -21,4 +21,5 @@ from __future__ import print_function from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_partition_examples from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction +from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction_verbose # pylint: enable=unused-import diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 5dd2e0c7f254f312932db6bb4a98734e46644e46..28fbf07fe46efd50d30df0085c155a64c5db3517 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -58,8 +58,20 @@ NUM_LAYERS_ATTEMPTED = "num_layers" NUM_TREES_ATTEMPTED = "num_trees" NUM_USED_HANDLERS = "num_used_handlers" USED_HANDLERS_MASK = "used_handlers_mask" +LEAF_INDEX = "leaf_index" _FEATURE_NAME_TEMPLATE = "%s_%d" +# Keys in Training state. +_NUM_LAYER_EXAMPLES = "num_layer_examples" +_NUM_LAYER_STEPS = "num_layer_steps" +_NUM_LAYERS = "num_layers" +_ACTIVE_TREE = "active_tree" +_ACTIVE_LAYER = "active_layer" +_CONTINUE_CENTERING = "continue_centering" +_BIAS_STATS_ACCUMULATOR = "bias_stats_accumulator" +_STEPS_ACCUMULATOR = "steps_accumulator" +_HANDLERS = "handlers" + def _get_column_by_index(tensor, indices): """Returns columns from a 2-D tensor by index.""" @@ -71,18 +83,24 @@ def _get_column_by_index(tensor, indices): return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1]) -def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats, - used_handlers): +def _make_predictions_dict(stamp, + logits, + partition_ids, + ensemble_stats, + used_handlers, + leaf_index=None): """Returns predictions for the given logits and n_classes. Args: stamp: The ensemble stamp. - logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. - that contains predictions when no dropout was applied. + logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. that + contains predictions when no dropout was applied. partition_ids: A rank 1 `Tensor` with shape [batch_size]. ensemble_stats: A TreeEnsembleStatsOp result tuple. used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a - boolean mask.. + boolean mask. + leaf_index: A rank 2 `Tensor` with shape [batch_size, number of trees]. that + contains leaf id for each example prediction. Returns: A dict of predictions. @@ -95,6 +113,8 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats, result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask + if leaf_index is not None: + result[LEAF_INDEX] = leaf_index return result @@ -268,7 +288,8 @@ class GradientBoostedDecisionTreeModel(object): features, logits_dimension, feature_columns=None, - use_core_columns=False): + use_core_columns=False, + output_leaf_index=False): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -276,13 +297,15 @@ class GradientBoostedDecisionTreeModel(object): num_ps_replicas: Number of parameter server replicas, can be 0. ensemble_handle: A handle to the ensemble variable. center_bias: Whether to center the bias before growing trees. - examples_per_layer: Number of examples to accumulate before growing - a tree layer. It can also be a function that computes the number of - examples based on the depth of the layer that's being built. + examples_per_layer: Number of examples to accumulate before growing a tree + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. learner_config: A learner config. features: `dict` of `Tensor` objects. logits_dimension: An int, the dimension of logits. feature_columns: A list of feature columns. + output_leaf_index: A boolean variable indicating whether to output leaf + index into predictions dictionary. Raises: ValueError: if inputs are not valid. @@ -313,6 +336,19 @@ class GradientBoostedDecisionTreeModel(object): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + if logits_dimension == 1 or learner_config.multi_class_strategy == ( + learner_pb2.LearnerConfig.TREE_PER_CLASS): + self._gradient_shape = tensor_shape.scalar() + self._hessian_shape = tensor_shape.scalar() + else: + self._gradient_shape = tensor_shape.TensorShape([logits_dimension]) + if (learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.FULL_HESSIAN): + self._hessian_shape = tensor_shape.TensorShape( + ([logits_dimension, logits_dimension])) + else: + # Diagonal hessian strategy. + self._hessian_shape = tensor_shape.TensorShape(([logits_dimension])) if (learner_config.growing_mode == learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED): learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER @@ -359,6 +395,7 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config.multi_class_strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS and learner_config.num_classes == 2) + self._output_leaf_index = output_leaf_index def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): """Runs prediction and returns a dictionary of the prediction results. @@ -388,22 +425,44 @@ class GradientBoostedDecisionTreeModel(object): # Make sure ensemble stats run. This will check that the ensemble has # the right stamp. with ops.control_dependencies(ensemble_stats): - predictions, _ = prediction_ops.gradient_trees_prediction( - ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=mode != learn.ModeKeys.TRAIN, - use_locking=True, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim) + leaf_index = None + # Only used in infer (predict), not used in train and eval. + if self._output_leaf_index and mode == learn.ModeKeys.INFER: + predictions, _, leaf_index = ( + prediction_ops).gradient_trees_prediction_verbose( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) + else: + leaf_index = None + predictions, _ = prediction_ops.gradient_trees_prediction( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) partition_ids = prediction_ops.gradient_trees_partition_examples( ensemble_handle, self._dense_floats, @@ -416,7 +475,7 @@ class GradientBoostedDecisionTreeModel(object): use_locking=True) return _make_predictions_dict(ensemble_stamp, predictions, partition_ids, - ensemble_stats, used_handlers) + ensemble_stats, used_handlers, leaf_index) def predict(self, mode): """Returns predictions given the features and mode. @@ -487,14 +546,23 @@ class GradientBoostedDecisionTreeModel(object): return self._predict_and_return_dict(self._ensemble_handle, ensemble_stamp, mode) - def train(self, loss, predictions_dict, labels): - """Grows a new tree and adds it to the ensemble. + def _get_class_id(self, predictions_dict): + # Handle different multiclass strategies. + if (self._learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.TREE_PER_CLASS and + self._logits_dimension != 1): + # Choose the class for which the tree is built (one vs rest). + return math_ops.to_int32( + predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) + return constant_op.constant(-1, dtype=dtypes.int32) + + def update_stats(self, loss, predictions_dict): + """Update the accumulators with stats from this batch. Args: loss: A scalar tensor representing average loss of examples. predictions_dict: Dictionary of Rank 2 `Tensor` representing information about predictions per example. - labels: Rank 2 `Tensor` representing labels per example. Returns: An op that adds a new tree to the ensemble. @@ -507,6 +575,44 @@ class GradientBoostedDecisionTreeModel(object): self._dense_floats + self._sparse_float_indices + self._sparse_int_indices) worker_device = input_deps[0].device + # Create ensemble stats variables. + num_layer_examples = variables.Variable( + initial_value=array_ops.zeros([], dtypes.int64), + name="num_layer_examples", + trainable=False) + num_layer_steps = variables.Variable( + initial_value=array_ops.zeros([], dtypes.int64), + name="num_layer_steps", + trainable=False) + num_layers = variables.Variable( + initial_value=array_ops.zeros([], dtypes.int64), + name="num_layers", + trainable=False) + active_tree = variables.Variable( + initial_value=array_ops.zeros([], dtypes.int64), + name="active_tree", + trainable=False) + active_layer = variables.Variable( + initial_value=array_ops.zeros([], dtypes.int64), + name="active_layer", + trainable=False) + # Variable that becomes false once bias centering is done. + continue_centering = variables.Variable( + initial_value=self._center_bias, + name="continue_centering", + trainable=False) + # Create bias stats accumulator. + bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, + name="BiasAccumulator") + # Create steps accumulator. + steps_accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=tensor_shape.scalar(), + hessian_shape=tensor_shape.scalar(), + name="StepsAccumulator") # Get tensors relevant for training and form the loss. predictions = predictions_dict[PREDICTIONS] @@ -521,13 +627,10 @@ class GradientBoostedDecisionTreeModel(object): aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - class_id = constant_op.constant(-1, dtype=dtypes.int32) + class_id = self._get_class_id(predictions_dict) # Handle different multiclass strategies. if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS: # We build one vs rest trees. - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() - if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. hessians = gradients_impl.gradients( @@ -544,11 +647,6 @@ class GradientBoostedDecisionTreeModel(object): hessian_list = self._diagonal_hessian(gradients, predictions) # Assemble hessian list into a tensor. hessians = array_ops.stack(hessian_list, axis=1) - - # Choose the class for which the tree is built (one vs rest). - class_id = math_ops.to_int32( - predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) - # Use class id tensor to get the column with that index from gradients # and hessians. squeezed_gradients = array_ops.squeeze( @@ -557,15 +655,10 @@ class GradientBoostedDecisionTreeModel(object): _get_column_by_index(hessians, class_id)) else: # Other multiclass strategies. - gradient_shape = tensor_shape.TensorShape([self._logits_dimension]) - if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: - hessian_shape = tensor_shape.TensorShape( - ([self._logits_dimension, self._logits_dimension])) hessian_list = self._full_hessian(gradients, predictions) else: # Diagonal hessian strategy. - hessian_shape = tensor_shape.TensorShape(([self._logits_dimension])) hessian_list = self._diagonal_hessian(gradients, predictions) squeezed_gradients = gradients @@ -573,7 +666,7 @@ class GradientBoostedDecisionTreeModel(object): squeezed_hessians = hessians # Get the weights for each example for quantiles calculation, - weights = self._get_weights(hessian_shape, squeezed_hessians) + weights = self._get_weights(self._hessian_shape, squeezed_hessians) # Create all handlers ensuring resources are evenly allocated across PS. fc_name_idx = 0 @@ -605,8 +698,8 @@ class GradientBoostedDecisionTreeModel(object): num_quantiles=num_quantiles, dense_float_column=self._dense_floats[dense_float_column_idx], name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -628,8 +721,8 @@ class GradientBoostedDecisionTreeModel(object): self._sparse_float_values[sparse_float_column_idx], self._sparse_float_shapes[sparse_float_column_idx]), name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -649,48 +742,12 @@ class GradientBoostedDecisionTreeModel(object): self._sparse_int_values[sparse_int_column_idx], self._sparse_int_shapes[sparse_int_column_idx]), name=fc_name, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, + gradient_shape=self._gradient_shape, + hessian_shape=self._hessian_shape, multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 - # Create steps accumulator. - steps_accumulator = stats_accumulator_ops.StatsAccumulator( - stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar(), - name="StepsAccumulator") - - # Create bias stats accumulator. - bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator( - stamp_token=0, - gradient_shape=gradient_shape, - hessian_shape=hessian_shape, - name="BiasAccumulator") - - # Create ensemble stats variables. - num_layer_examples = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), - name="num_layer_examples", - trainable=False) - num_layer_steps = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), - name="num_layer_steps", - trainable=False) - num_layers = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), - name="num_layers", - trainable=False) - active_tree = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), - name="active_tree", - trainable=False) - active_layer = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), - name="active_layer", - trainable=False) - # Create ensemble stats summaries. summary.scalar("layer_stats/num_examples", num_layer_examples) summary.scalar("layer_stats/num_steps", num_layer_steps) @@ -699,16 +756,13 @@ class GradientBoostedDecisionTreeModel(object): # Update bias stats. stats_update_ops = [] - continue_centering = variables.Variable( - initial_value=self._center_bias, - name="continue_centering", - trainable=False) + stats_update_ops.append( control_flow_ops.cond( continue_centering, - self._make_update_bias_stats_fn(ensemble_stamp, predictions, - gradients, bias_stats_accumulator), - control_flow_ops.no_op)) + self._make_update_bias_stats_fn( + ensemble_stamp, predictions, gradients, + bias_stats_accumulator), control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -765,8 +819,8 @@ class GradientBoostedDecisionTreeModel(object): lambda: active_handlers)) # Prepare empty gradients and hessians when handlers are not ready. - empty_hess_shape = [1] + hessian_shape.as_list() - empty_grad_shape = [1] + gradient_shape.as_list() + empty_hess_shape = [1] + self._hessian_shape.as_list() + empty_grad_shape = [1] + self._gradient_shape.as_list() empty_gradients = constant_op.constant( [], dtype=dtypes.float32, shape=empty_grad_shape) @@ -788,34 +842,66 @@ class GradientBoostedDecisionTreeModel(object): per_handler_updates, ensemble_stamp, worker_device) for update in update_results.values(): stats_update_ops += update + + training_state = { + _NUM_LAYER_EXAMPLES: num_layer_examples, + _NUM_LAYER_STEPS: num_layer_steps, + _NUM_LAYERS: num_layers, + _ACTIVE_TREE: active_tree, + _ACTIVE_LAYER: active_layer, + _CONTINUE_CENTERING: continue_centering, + _BIAS_STATS_ACCUMULATOR: bias_stats_accumulator, + _STEPS_ACCUMULATOR: steps_accumulator, + _HANDLERS: handlers + } + return stats_update_ops, training_state + + def increment_step_counter_and_maybe_update_ensemble( + self, predictions_dict, batch_size, training_state): + """Increments number of visited examples and grows the ensemble. + + If the number of visited examples reaches the target examples_per_layer, + ensemble is updated. + + Args: + predictions_dict: Dictionary of Rank 2 `Tensor` representing information + about predictions per example. + batch_size: Number of examples in the batch. + training_state: `dict` returned by update_stats. + + Returns: + An op that updates the counters and potientially grows the ensemble. + """ + ensemble_stamp = predictions_dict[ENSEMBLE_STAMP] # Accumulate a step after updating stats. - batch_size = math_ops.cast(array_ops.shape(labels)[0], dtypes.float32) - with ops.control_dependencies(stats_update_ops): - add_step_op = steps_accumulator.add(ensemble_stamp, [0], [[0, 0]], - [batch_size], [1.0]) - # Determine learning rate. - learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof( - "tuner") - if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout": - tuner = getattr(self._learner_config.learning_rate_tuner, - learning_rate_tuner) - learning_rate = tuner.learning_rate - else: - # TODO(nponomareva, soroush) do the line search. - raise ValueError("Line search learning rate is not yet supported.") + num_layer_examples = training_state[_NUM_LAYER_EXAMPLES] + num_layer_steps = training_state[_NUM_LAYER_STEPS] + num_layers = training_state[_NUM_LAYERS] + active_tree = training_state[_ACTIVE_TREE] + active_layer = training_state[_ACTIVE_LAYER] + continue_centering = training_state[_CONTINUE_CENTERING] + bias_stats_accumulator = training_state[_BIAS_STATS_ACCUMULATOR] + steps_accumulator = training_state[_STEPS_ACCUMULATOR] + handlers = training_state[_HANDLERS] + add_step_op = steps_accumulator.add( + ensemble_stamp, [0], [[0, 0]], [batch_size], [1.0]) # After adding the step, decide if further processing is needed. ensemble_update_ops = [add_step_op] + class_id = self._get_class_id(predictions_dict) + with ops.control_dependencies([add_step_op]): if self._is_chief: dropout_seed = predictions_dict[NUM_TREES_ATTEMPTED] # Get accumulated steps and examples for the current layer. - _, _, _, _, acc_examples, acc_steps = steps_accumulator.serialize() + _, _, _, _, acc_examples, acc_steps = ( + steps_accumulator.serialize()) acc_examples = math_ops.cast(acc_examples[0], dtypes.int64) acc_steps = math_ops.cast(acc_steps[0], dtypes.int64) - ensemble_update_ops.append(num_layer_examples.assign(acc_examples)) + ensemble_update_ops.append( + num_layer_examples.assign(acc_examples)) ensemble_update_ops.append(num_layer_steps.assign(acc_steps)) # Determine whether we need to update tree ensemble. examples_per_layer = self._examples_per_layer @@ -824,139 +910,33 @@ class GradientBoostedDecisionTreeModel(object): ensemble_update_ops.append( control_flow_ops.cond( acc_examples >= examples_per_layer, - self._make_update_ensemble_fn( - ensemble_stamp, steps_accumulator, bias_stats_accumulator, - continue_centering, learning_rate, handlers, num_layers, - active_tree, active_layer, dropout_seed, class_id), + self.make_update_ensemble_fn( + ensemble_stamp, steps_accumulator, + bias_stats_accumulator, continue_centering, + handlers, num_layers, active_tree, + active_layer, dropout_seed, class_id), control_flow_ops.no_op)) - # Calculate the loss to be reported. # Note, the loss is calculated from the prediction considering dropouts, so # that the value might look staggering over steps when the dropout ratio is # high. eval_loss might be referred instead in the aspect of convergence. return control_flow_ops.group(*ensemble_update_ops) - def _get_weights(self, hessian_shape, hessians): - """Derives weights to be used based on hessians and multiclass strategy.""" - if hessian_shape == tensor_shape.scalar(): - # This is tree per class. - weights = hessians - elif len(hessian_shape.dims) == 1: - # This is diagonal hessian. - weights = math_ops.reduce_sum(hessians, axis=1) - else: - # This is full hessian. - weights = math_ops.trace(hessians) - return weights - - def _full_hessian(self, grads, predictions): - """Prepares hessians for full-hessian multiclass strategy.""" - # Because of - # https://github.com/tensorflow/tensorflow/issues/675, we can't just - # compute the full hessian with a single call to gradients, but instead - # must compute it row-by-row. - gradients_list = array_ops.unstack( - grads, num=self._logits_dimension, axis=1) - hessian_rows = [] - - for row in range(self._logits_dimension): - # If current row is i, K is number of classes,each row returns a tensor of - # size batch_size x K representing for each example dx_i dx_1, dx_i dx_2 - # etc dx_i dx_K - hessian_row = gradients_impl.gradients( - gradients_list[row], - predictions, - name="Hessian_%d" % row, - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None) - - # hessian_row is of dimension 1, batch_size, K, => trim first dimension - # to get batch_size x K - hessian_row = array_ops.squeeze(array_ops.unstack(hessian_row), [0]) - hessian_rows.append(hessian_row) - return hessian_rows - - def _diagonal_hessian(self, grads, predictions): - """Prepares hessians for diagonal-hessian multiclass mode.""" - diag_hessian_list = [] - - gradients_list = array_ops.unstack( - grads, num=self._logits_dimension, axis=1) - - for row, row_grads in enumerate(gradients_list): - # If current row is i, K is number of classes,each row returns a tensor of - # size batch_size x K representing for each example dx_i dx_1, dx_1 dx_2 - # etc dx_i dx_K - hessian_row = gradients_impl.gradients( - row_grads, - predictions, - name="Hessian_%d" % row, - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None) - - # hessian_row is of dimension 1, batch_size, K, => trim first dimension - # to get batch_size x K - hessian_row = array_ops.squeeze(array_ops.unstack(hessian_row), [0]) - - # Get dx_i^2 for the whole batch. - elem = array_ops.transpose(hessian_row)[row] - diag_hessian_list.append(elem) - - return diag_hessian_list - - def _get_replica_device_setter(self, worker_device): - """Creates a replica device setter.""" - ps_tasks = self._num_ps_replicas - ps_ops = [ - "Variable", - "VariableV2", - "DecisionTreeEnsembleResourceHandleOp", - "StatsAccumulatorScalarResourceHandleOp", - "StatsAccumulatorTensorResourceHandleOp", - ] - ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) - return device_setter.replica_device_setter( - worker_device=worker_device, - ps_tasks=ps_tasks, - merge_devices=True, - ps_ops=ps_ops, - ps_strategy=ps_strategy) - - def _make_update_bias_stats_fn(self, ensemble_stamp, predictions, gradients, - bias_stats_accumulator): - """A method to create the function which updates the bias stats.""" - - def _update_bias_stats(): - """A method to update the bias stats.""" - # Get reduced gradients and hessians. - grads_sum = math_ops.reduce_sum(gradients, 0) - hess = gradients_impl.gradients( - grads_sum, - predictions, - name="Hessians", - colocate_gradients_with_ops=False, - gate_gradients=0, - aggregation_method=None)[0] - hess_sum = math_ops.reduce_sum(hess, 0) - - # Accumulate gradients and hessians. - partition_ids = math_ops.range(self._logits_dimension) - feature_ids = array_ops.zeros( - [self._logits_dimension, 2], dtype=dtypes.int64) - - add_stats_op = bias_stats_accumulator.add( - ensemble_stamp, partition_ids, feature_ids, grads_sum, hess_sum) - return control_flow_ops.group(*[add_stats_op], name="update_bias_stats") - - return _update_bias_stats - - def _make_update_ensemble_fn(self, ensemble_stamp, steps_accumulator, - bias_stats_accumulator, continue_centering, - learning_rate, handlers, num_layers, active_tree, - active_layer, dropout_seed, class_id): + def make_update_ensemble_fn(self, ensemble_stamp, steps_accumulator, + bias_stats_accumulator, continue_centering, + handlers, num_layers, active_tree, active_layer, + dropout_seed, class_id): """A method to create the function which updates the tree ensemble.""" + # Determine learning rate. + learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof( + "tuner") + if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout": + tuner = getattr(self._learner_config.learning_rate_tuner, + learning_rate_tuner) + learning_rate = tuner.learning_rate + else: + # TODO(nponomareva, soroush) do the line search. + raise ValueError("Line search learning rate is not yet supported.") def _update_ensemble(): """A method to update the tree ensemble.""" @@ -1075,3 +1055,140 @@ class GradientBoostedDecisionTreeModel(object): def get_number_of_trees_tensor(self): return self._finalized_trees, self._attempted_trees + + def train(self, loss, predictions_dict, labels): + """Updates the accumalator stats and grows the ensemble. + + Args: + loss: A scalar tensor representing average loss of examples. + predictions_dict: Dictionary of Rank 2 `Tensor` representing information + about predictions per example. + labels: Rank 2 `Tensor` representing labels per example. + + Returns: + An op that adds a new tree to the ensemble. + + Raises: + ValueError: if inputs are not valid. + """ + batch_size = math_ops.cast(array_ops.shape(labels)[0], dtypes.float32) + update_op, handlers = self.update_stats(loss, predictions_dict) + with ops.control_dependencies(update_op): + return self.increment_step_counter_and_maybe_update_ensemble( + predictions_dict, batch_size, handlers) + + def _get_weights(self, hessian_shape, hessians): + """Derives weights to be used based on hessians and multiclass strategy.""" + if hessian_shape == tensor_shape.scalar(): + # This is tree per class. + weights = hessians + elif len(hessian_shape.dims) == 1: + # This is diagonal hessian. + weights = math_ops.reduce_sum(hessians, axis=1) + else: + # This is full hessian. + weights = math_ops.trace(hessians) + return weights + + def _full_hessian(self, grads, predictions): + """Prepares hessians for full-hessian multiclass strategy.""" + # Because of + # https://github.com/tensorflow/tensorflow/issues/675, we can't just + # compute the full hessian with a single call to gradients, but instead + # must compute it row-by-row. + gradients_list = array_ops.unstack( + grads, num=self._logits_dimension, axis=1) + hessian_rows = [] + + for row in range(self._logits_dimension): + # If current row is i, K is number of classes,each row returns a tensor of + # size batch_size x K representing for each example dx_i dx_1, dx_i dx_2 + # etc dx_i dx_K + hessian_row = gradients_impl.gradients( + gradients_list[row], + predictions, + name="Hessian_%d" % row, + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None) + + # hessian_row is of dimension 1, batch_size, K, => trim first dimension + # to get batch_size x K + hessian_row = array_ops.squeeze(array_ops.unstack(hessian_row), [0]) + hessian_rows.append(hessian_row) + return hessian_rows + + def _diagonal_hessian(self, grads, predictions): + """Prepares hessians for diagonal-hessian multiclass mode.""" + diag_hessian_list = [] + + gradients_list = array_ops.unstack( + grads, num=self._logits_dimension, axis=1) + + for row, row_grads in enumerate(gradients_list): + # If current row is i, K is number of classes,each row returns a tensor of + # size batch_size x K representing for each example dx_i dx_1, dx_1 dx_2 + # etc dx_i dx_K + hessian_row = gradients_impl.gradients( + row_grads, + predictions, + name="Hessian_%d" % row, + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None) + + # hessian_row is of dimension 1, batch_size, K, => trim first dimension + # to get batch_size x K + hessian_row = array_ops.squeeze(array_ops.unstack(hessian_row), [0]) + + # Get dx_i^2 for the whole batch. + elem = array_ops.transpose(hessian_row)[row] + diag_hessian_list.append(elem) + + return diag_hessian_list + + def _get_replica_device_setter(self, worker_device): + """Creates a replica device setter.""" + ps_tasks = self._num_ps_replicas + ps_ops = [ + "Variable", + "VariableV2", + "DecisionTreeEnsembleResourceHandleOp", + "StatsAccumulatorScalarResourceHandleOp", + "StatsAccumulatorTensorResourceHandleOp", + ] + ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) + return device_setter.replica_device_setter( + worker_device=worker_device, + ps_tasks=ps_tasks, + merge_devices=True, + ps_ops=ps_ops, + ps_strategy=ps_strategy) + + def _make_update_bias_stats_fn(self, ensemble_stamp, predictions, gradients, + bias_stats_accumulator): + """A method to create the function which updates the bias stats.""" + + def _update_bias_stats(): + """A method to update the bias stats.""" + # Get reduced gradients and hessians. + grads_sum = math_ops.reduce_sum(gradients, 0) + hess = gradients_impl.gradients( + grads_sum, + predictions, + name="Hessians", + colocate_gradients_with_ops=False, + gate_gradients=0, + aggregation_method=None)[0] + hess_sum = math_ops.reduce_sum(hess, 0) + + # Accumulate gradients and hessians. + partition_ids = math_ops.range(self._logits_dimension) + feature_ids = array_ops.zeros( + [self._logits_dimension, 2], dtype=dtypes.int64) + + add_stats_op = bias_stats_accumulator.add( + ensemble_stamp, partition_ids, feature_ids, grads_sum, hess_sum) + return control_flow_ops.group(*[add_stats_op], name="update_bias_stats") + + return _update_bias_stats diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 289fb195db109f25c9c4599dcfe076ac98298383..e3d4397fadcbaf148f7f6cfaca13e850639786cf 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -19,18 +19,15 @@ from __future__ import division from __future__ import print_function from google.protobuf import text_format - from tensorflow.contrib import layers from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch from tensorflow.contrib.boosted_trees.python.utils import losses - -from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn - +from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util @@ -782,6 +779,118 @@ class GbdtTest(test_util.TensorFlowTestCase): [[0.25], [0.25], [0.25], [0.25]]) self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) + def testPredictFnWithLeafIndexAdvancedLeft(self): + """Tests the predict function with output leaf ids.""" + with self.test_session() as sess: + # Create ensemble with one bias node. + ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + dense_float_binary_split { + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.15 + } + } + } + } + trees { + nodes { + dense_float_binary_split { + threshold: 0.99 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 00 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.23 + } + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + }""", ensemble_config) + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=3, + tree_ensemble_config=ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.constant( + [[0.0], [1.0], [1.1], [2.0]], dtype=dtypes.float32) + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=False, + num_ps_replicas=0, + center_bias=True, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features, + output_leaf_index=True) + + # Create predict op. + mode = model_fn.ModeKeys.INFER + predictions_dict = sess.run(gbdt_model.predict(mode)) + self.assertEquals(predictions_dict["ensemble_stamp"], 3) + # here are how the numbers in expected results are calculated, + # 0.5 = 0.25 + 0.25 + # 0.48 = 0.25 + 0.23 + # 0.38 = 0.15 + 0.23 + # 0.38 = 0.15 + 0.23 + self.assertAllClose(predictions_dict["predictions"], + [[0.5], [0.48], [0.38], [0.38]]) + self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) + self.assertAllClose(predictions_dict["leaf_index"], + [[1, 1], [1, 2], [2, 2], [2, 2]]) + def testTrainFnMulticlassFullHessian(self): """Tests the GBDT train for multiclass full hessian.""" with self.test_session() as sess: diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 8ae493ba998bd882b5ef946f927ec1882d91f61d..38856417c0794da77ddbce3ad36977060e15b7a4 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -16,10 +16,13 @@ Visualization and inspection: @@dot_graph_from_checkpoint +@@list_objects @@object_metadata Managing dependencies: +@@capture_dependencies @@Checkpointable +@@CheckpointableBase @@CheckpointableObjectGraph @@NoDependency @@split_dependency @@ -39,12 +42,14 @@ from tensorflow.contrib.checkpoint.python.split_dependency import split_dependen from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.base import CheckpointableBase from tensorflow.python.training.checkpointable.base import NoDependency from tensorflow.python.training.checkpointable.data_structures import List from tensorflow.python.training.checkpointable.data_structures import Mapping +from tensorflow.python.training.checkpointable.util import capture_dependencies +from tensorflow.python.training.checkpointable.util import list_objects from tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) - diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index 42ba368531468b789a87429f88ca84937f9b909d..1a7a3759baa4a5559b4b70ff4f7467c41da9111f 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -74,3 +74,14 @@ tf_py_test( ], tags = ["manual"], ) + +tf_py_test( + name = "gcs_config_ops_test", + size = "small", + srcs = ["python/ops/gcs_config_ops_test.py"], + additional_deps = [ + ":cloud_py", + "//tensorflow/python:client_testlib", + ], + tags = ["manual"], +) diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index a6e13ea3ae938444b9ead0772e52fb8797a847da..ef7aa7624ce7b9b6480c4d088a2fb7678a7acc76 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -27,8 +27,9 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'BigQueryReader', - 'ConfigureColabSession', - 'ConfigureGcs', + 'BlockCacheParams', + 'configure_colab_session', + 'configure_gcs', 'ConfigureGcsHook', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 40160706f70e8fa8323005dd183770ed51c8c415..1311063ec023bdaa2588d6f1c826bf900f7dea09 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -79,6 +79,7 @@ tf_kernel_library( srcs = ["gcs_config_ops.cc"], visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/cloud:gcs_config_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/platform/cloud:curl_http_request", diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py index 8c8c5acb31af69b4f738a13c6548cdd31947d71a..95e7e744d34391a511cdba7702aad369b8d9d9c0 100644 --- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -120,13 +120,18 @@ class ConfigureGcsHook(training.SessionRunHook): def begin(self): if self._credentials: self._credentials_placeholder = array_ops.placeholder(dtypes.string) - self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_op = gen_gcs_config_ops.gcs_configure_credentials( self._credentials_placeholder) + else: + self._credentials_op = None + if self._block_cache: self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( max_cache_size=self._block_cache.max_bytes, block_size=self._block_cache.block_size, max_staleness=self._block_cache.max_staleness) + else: + self._block_cache_op = None def after_create_session(self, session, coord): del coord diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6c056d6c8adfa50b95aefb8e9740631327a572 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py @@ -0,0 +1,44 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the gcs_config_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cloud.python.ops import gcs_config_ops +from tensorflow.python.platform import test + + +class GcsConfigOpsTest(test.TestCase): + + def testSetBlockCache(self): + cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024) + with self.test_session() as sess: + gcs_config_ops.configure_gcs(sess, block_cache=cfg) + + def testConfigureGcsHook(self): + creds = {'client_id': 'fake_client', + 'refresh_token': 'fake_token', + 'client_secret': 'fake_secret', + 'type': 'authorized_user'} + hook = gcs_config_ops.ConfigureGcsHook(credentials=creds) + hook.begin() + with self.test_session() as sess: + sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None + hook.after_create_session(sess, None) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index d44e23aadc2ce5efb236eeba2ed148c698fe7528..8f521ffee4d31e090c13bac98290656d6e1d330e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -36,6 +36,7 @@ except ImportError: _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' +_ENDPOINTS_SEPARATOR = ',' _DEFAULT_ENV_VARIABLE = 'TPU_NAME' _DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' @@ -69,8 +70,8 @@ class TPUClusterResolver(ClusterResolver): return _GKE_ENV_VARIABLE in os.environ @staticmethod - def _gkeMaster(): - return os.environ[_GKE_ENV_VARIABLE].split(',')[0] + def _gkeEndpoints(): + return os.environ[_GKE_ENV_VARIABLE] @staticmethod def _envVarFallback(): @@ -143,7 +144,7 @@ class TPUClusterResolver(ClusterResolver): # When using GKE with Cloud TPUs, the env variable will be set. if tpu is None: if in_gke: - tpu = self._gkeMaster() + tpu = self._gkeEndpoints() else: tpu = self._envVarFallback() @@ -173,7 +174,7 @@ class TPUClusterResolver(ClusterResolver): raise ImportError('googleapiclient and oauth2client must be installed ' 'before using the TPU cluster resolver. Execute: ' '`pip install --upgrade google-api-python-client` ' - 'and `pip install --upgrade oauth2lclient` to ' + 'and `pip install --upgrade oauth2client` to ' 'install with pip.') final_discovery_url = self._discoveryUrl() or discovery_url @@ -214,7 +215,7 @@ class TPUClusterResolver(ClusterResolver): ValueError: If none of the TPUs specified exists. """ if not self._shouldResolve(): - return self._tpu + return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] job_tasks = self.cluster_spec().job_tasks(self._job_name) if not job_tasks: @@ -256,6 +257,10 @@ class TPUClusterResolver(ClusterResolver): request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() + if 'state' in response and response['state'] != 'READY': + raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % + (self._tpu, response['state'])) + if 'health' in response and response['health'] != 'HEALTHY': raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, response['health'])) @@ -276,8 +281,12 @@ class TPUClusterResolver(ClusterResolver): # Case 3. return None # Case 2. - cluster_spec = {self._job_name: [self._tpu[len( - compat.as_bytes('grpc://')):]]} + cluster_spec = { + self._job_name: [ + x[len(compat.as_bytes('grpc://')):] + for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) + ] + } if self._coordinator_address: # {1, 2}.a diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5fac55fd027fa2d100621e08a09e05cdb3a1b941..ad4f6432630be44a7de6e778f55f1fb7fd66f307 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -158,6 +158,50 @@ class TPUClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testUnhealthyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'health': 'UNHEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testNotReadyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'state': 'CREATING' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { @@ -358,13 +402,61 @@ class TPUClusterResolverTest(test.TestCase): compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) self.assertEqual(None, tpu_cluster_resolver.cluster_spec()) - def testGkeEnvironment(self): + def testGkeEnvironmentForDonut(self): os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' - self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ) + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) + self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + + def testGkeEnvironmentForPod(self): + os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470') + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() self.assertEqual( compat.as_bytes('grpc://10.120.27.5:8470'), - compat.as_bytes(TPUClusterResolver._gkeMaster())) + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + tasks { key: 1 value: '10.120.27.6:8470' } + tasks { key: 2 value: '10.120.27.7:8470' } + tasks { key: 3 value: '10.120.27.8:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] def testDiscoveryUrl(self): diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 0708d6b7b9f0ba549aea091a265f42890e50d223..4ca7a1b28c6edbda7b5f9c78236fb4437c43afa6 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,16 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) + +if(WIN32) +# BoringSSL is disabled for windows as it currently doesn't build with +# MSBuild. (Ninja is required.) option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) +else() +# BoringSSL is enabled for gRPC. +option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" ON) +endif() + option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF) @@ -327,40 +336,14 @@ endif() # MKL Support if (tensorflow_ENABLE_MKL_SUPPORT) add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) - if (WIN32) - find_path(MKL_HOME_PLATFORM mkl - PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ - $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ - PATH_SUFFIXES windows) - set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) - set(MKL_LINK_DIRS - ${MKL_HOME_PLATFORM}/mkl/lib/intel64 - ${MKL_HOME_PLATFORM}/tbb/lib/intel64/vc_mt - ${MKL_HOME_PLATFORM}/compiler/lib/intel64 - ${MKL_HOME_PLATFORM}/mkl/tools/builder/lib) - set(MKL_REDIST_DLL_DIRS - ${MKL_HOME_PLATFORM}/redist/intel64/mkl - ${MKL_HOME_PLATFORM}/redist/intel64/tbb/vc_mt - ${MKL_HOME_PLATFORM}/redist/intel64/compiler) - list(APPEND tensorflow_EXTERNAL_LIBRARIES - mkl_intel_lp64_dll mkl_sequential_dll mkl_core_dll mkl_rt mkl_cdll_intel64) - endif() - if (UNIX) - # Fix me: complete the path on linux - find_path(MKL_HOME_PLATFORM mkl - HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ - $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ - PATH_SUFFIXES linux) - set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) - set(MKL_LINK_DIRS) # incompleted - set(MKL_REDIST_SO_DIRS) # incompleted - endif() - include_directories(${MKL_INCLUDE_DIRS}) - link_directories(${MKL_LINK_DIRS}) + include(mkl) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination) + include_directories(${mkl_INCLUDE_DIRS}) if (tensorflow_ENABLE_MKLDNN_SUPPORT) include(mkldnn) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) - list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination) include_directories(${mkldnn_INCLUDE_DIRS}) else (tensorflow_ENABLE_MKLDNN_SUPPORT) add_definitions(-DINTEL_MKL_ML) diff --git a/tensorflow/contrib/cmake/external/double_conversion.cmake b/tensorflow/contrib/cmake/external/double_conversion.cmake index 527ccdc8d887cb4c2e7d2412c99a8bc682568472..5c5adaf5798289fba1c5d0b3f9e0489dc242043e 100644 --- a/tensorflow/contrib/cmake/external/double_conversion.cmake +++ b/tensorflow/contrib/cmake/external/double_conversion.cmake @@ -16,15 +16,15 @@ include (ExternalProject) set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion) set(double_conversion_URL https://github.com/google/double-conversion.git) -set(double_conversion_TAG 5664746) +set(double_conversion_TAG 3992066a95b823efc8ccc1baf82a1cfc73f6e9b8) set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR}) set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so) set(double_conversion_INCLUDES ${double_conversion_BUILD}) if(WIN32) - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/$(Configuration)/double-conversion.lib) + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/$(Configuration)/double-conversion.lib) else() - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.a) + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/libdouble-conversion.a) endif() set(double_conversion_HEADERS diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 693dc7cd673233b889b35a3f3170b57581da9a9f..b1e64aa55c80ad59cfdc0f4767c0282b4f73367f 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -20,6 +20,10 @@ set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) if(WIN32) + # We use unsecure gRPC because boringssl does not build on windows + set(grpc_TARGET grpc++_unsecure) + set(grpc_DEPENDS protobuf zlib) + set(grpc_SSL_PROVIDER NONE) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib @@ -32,9 +36,12 @@ if(WIN32) ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib) endif() else() + set(grpc_TARGET grpc++) + set(grpc_DEPENDS boringssl protobuf zlib) + set(grpc_SSL_PROVIDER module) set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) @@ -44,13 +51,13 @@ add_definitions(-DGRPC_ARES=0) ExternalProject_Add(grpc PREFIX grpc - DEPENDS protobuf zlib + DEPENDS ${grpc_DEPENDS} GIT_REPOSITORY ${GRPC_URL} GIT_TAG ${GRPC_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target ${grpc_TARGET} COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin INSTALL_COMMAND "" CMAKE_CACHE_ARGS @@ -59,7 +66,7 @@ ExternalProject_Add(grpc -DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS} -DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} - -DgRPC_SSL_PROVIDER:STRING=NONE + -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} ) # grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h. diff --git a/tensorflow/contrib/cmake/external/mkl.cmake b/tensorflow/contrib/cmake/external/mkl.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a172e3a41a283359b9a8c823ddcb2b1973b5b3cc --- /dev/null +++ b/tensorflow/contrib/cmake/external/mkl.cmake @@ -0,0 +1,68 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +# NOTE: Different from mkldnn.cmake, this file is meant to download mkl libraries +set(mkl_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include) +set(mkl_BIN_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/bin) +set(mkl_WIN mklml_win_2018.0.3.20180406.zip) # match for v0.14 +set(mkl_MAC mklml_mac_2018.0.3.20180406.tgz) +set(mkl_LNX mklml_lnx_2018.0.3.20180406.tgz) +set(mkl_TAG v0.14) +set(mkl_URL https://github.com/intel/mkl-dnn/releases) + +if (WIN32) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_WIN}) + list(APPEND mkl_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.lib) + list(APPEND mkl_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.lib) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.dll) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.dll) +elseif (UNIX) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_LNX}) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5.so) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_gnu.so) + list(APPEND mkl_SHARED_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_intel.so) +elseif (APPLE) + set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_MAC}) + #TODO need more information +endif () + +ExternalProject_Add(mkl + PREFIX mkl + URL ${mkl_DOWNLOAD_URL} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "") + +# put mkl dynamic libraries in one bin directory +add_custom_target(mkl_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${mkl_BIN_DIRS} + DEPENDS mkl) + +add_custom_target(mkl_copy_shared_to_destination DEPENDS mkl_create_destination_dir) + +foreach(dll_file ${mkl_SHARED_LIBRARIES}) + add_custom_command(TARGET mkl_copy_shared_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dll_file} ${mkl_BIN_DIRS}) +endforeach() diff --git a/tensorflow/contrib/cmake/external/mkldnn.cmake b/tensorflow/contrib/cmake/external/mkldnn.cmake index a639fdee367f060d4c8a79267803da6ffe3dc503..8123ee1f393ab8e3a52f13915ea2a65decc188d9 100644 --- a/tensorflow/contrib/cmake/external/mkldnn.cmake +++ b/tensorflow/contrib/cmake/external/mkldnn.cmake @@ -22,8 +22,11 @@ set(mkldnn_TAG 3063b2e4c943983f6bf5f2fb9a490d4a998cd291) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib) + set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.dll) + set(mkldnn_BUILD ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release) else() set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib) + set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.dll) endif() else() set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a) @@ -31,6 +34,7 @@ endif() ExternalProject_Add(mkldnn PREFIX mkldnn + DEPENDS mkl GIT_REPOSITORY ${mkldnn_URL} GIT_TAG ${mkldnn_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" @@ -40,5 +44,11 @@ ExternalProject_Add(mkldnn CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DMKLINC:STRING=${MKL_INCLUDE_DIRS} + -DMKLINC:STRING=${mkl_INCLUDE_DIRS} ) + +# since mkldnn depends on mkl, copy the mkldnn.dll together with mklml.dll to mkl_bin_dirs +add_custom_target(mkldnn_copy_shared_to_destination DEPENDS mkldnn) + +add_custom_command(TARGET mkldnn_copy_shared_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${mkldnn_SHARED_LIBRARIES} ${mkl_BIN_DIRS}) diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index b9d1dd88d4c2d3c9141ba56e14911e06b4d33f7c..6d50a4956b8b525b231d4344b83481f3ab2699e9 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 0559ce013feac8db639ee1bf776aca0325d28777) +set(nsync_TAG 5e8b19a81e5729922629dd505daa651f6ffdf107) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index ab464bc99a43138130bb2758ae28ecef29805c31..f56fb35a0f71250f00b84e5cf94a24682bda6c82 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) +set(PROTOBUF_TAG v3.6.0) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index fece56c4127de4deebc1404f0eff9747f99ba89f..d530572e91825ed88d09c26a10693288878d09ed 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -35,6 +35,7 @@ tensorflow/python/keras tensorflow/python/keras/applications tensorflow/python/keras/datasets tensorflow/python/keras/engine +tensorflow/python/keras/estimator tensorflow/python/keras/layers tensorflow/python/keras/preprocessing tensorflow/python/keras/utils @@ -129,6 +130,7 @@ tensorflow/contrib/data tensorflow/contrib/data/kernels tensorflow/contrib/data/python tensorflow/contrib/data/python/kernel_tests +tensorflow/contrib/data/python/kernel_tests/serialization tensorflow/contrib/data/python/ops tensorflow/contrib/decision_trees tensorflow/contrib/decision_trees/proto diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 1959ad028a06f3c1ff6a658d656155541891fd13..786ea05c744167ad52d52cc73328bd8c25d78c3e 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -742,30 +742,113 @@ endforeach(api_init_file) set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt") file(WRITE "${api_init_list_file}" "${api_init_files}") +# Run create_python_api.py to generate __init__.py files. + +### TODO +# In order to download and compile MKL/MKL-DNN automatically in cmake script, mkl-built libraries should be added to system path +# to be loaded by python executor. However `add_custom_command` has an issue with `COMMAND ${CMAKE_COMMAND} -E env PATH=`, where +# arguments of multiple paths (such as D:/;D:/mkl) will be parsed in to seperate string without semicolon and that command fail to +# recongnize paths. As CUDA isn't built with MKL, the MKL built directory is the only path to this command to work around that issue. +# To not override the CUDA and system path in other circumstances, `if-else` branch used here to handle this problem, +# and should be removed if the path issue can be resolved. +### + +if (tensorflow_ENABLE_MKL_SUPPORT) + # add mkl dist dlls to system path for python + # TODO: In current cmake version, PY_RUNTIME_ENV behaves strange with multiple paths, + # so we have to specify only one path in it to work around the issue. We need this if/else + # to protect overwriting CUDA environments + set(PY_RUNTIME_ENV ${mkl_BIN_DIRS}) + add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python PATH=${PY_RUNTIME_ENV} ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" + VERBATIM + ) +else (tensorflow_ENABLE_MKL_SUPPORT) + add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # tensorflow/__init__.py depends on files generated in this step. So, remove it while + # this step is running since the files aren't there yet. + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" + ) +endif (tensorflow_ENABLE_MKL_SUPPORT) + +add_custom_target(tf_python_api SOURCES ${api_init_files}) +add_dependencies(tf_python_api tf_python_ops) + +# TODO(mikecase): This can be removed once tf.estimator is moved +# out of TensorFlow. +######################################################## +# Generate API __init__.py files for tf.estimator. +######################################################## + +# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) + +set(api_init_files "") +foreach(api_init_file ${api_init_files_list}) + string(STRIP "${api_init_file}" api_init_file) + if(api_init_file) + string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api/${api_init_file}") + endif() +endforeach(api_init_file) +set(estimator_api_init_list_file "${tensorflow_source_dir}/estimator_api_init_files_list.txt") +file(WRITE "${estimator_api_init_list_file}" "${api_init_files}") + # Run create_python_api.py to generate __init__.py files. add_custom_command( OUTPUT ${api_init_files} DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops - # tensorflow/__init__.py depends on files generated in this step. So, remove it while - # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - # Run create_python_api.py to generate API init files. COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" - "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" - "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" - "${api_init_list_file}" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api" + "--package=tensorflow.python.estimator" + "--apiname=estimator" + "${estimator_api_init_list_file}" COMMENT "Generating __init__.py files for Python API." WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" ) -add_custom_target(tf_python_api SOURCES ${api_init_files}) -add_dependencies(tf_python_api tf_python_ops) - - +add_custom_target(estimator_python_api SOURCES ${api_init_files}) +add_dependencies(estimator_python_api tf_python_ops) ############################################################ # Build a PIP package containing the TensorFlow runtime. ############################################################ @@ -776,6 +859,7 @@ add_dependencies(tf_python_build_pip_package tf_python_touchup_modules tf_python_ops tf_python_api + estimator_python_api tf_extension_ops) # Fix-up Python files that were not included by the add_python_module() macros. diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 38f40452b533fdc0dba6ac686a0ff43a2ef13cb8..fdf522f1fd90ffc64acbe82381ef57a389645d61 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -145,3 +145,8 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ # unsupported Eigen directory install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ DESTINATION include/unsupported/Eigen) +# mkl +if (tensorflow_ENABLE_MKL_SUPPORT) + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/ + DESTINATION include/mkl) +endif (tensorflow_ENABLE_MKL_SUPPORT) diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md index c65a150464efc1e77419040f66f36fc6756325aa..cb1dd7d836ae11700b2ffaaff4fda5b7f943f87d 100644 --- a/tensorflow/contrib/constrained_optimization/README.md +++ b/tensorflow/contrib/constrained_optimization/README.md @@ -46,7 +46,7 @@ document. Imagine that we want to constrain the recall of a binary classifier to be at least 90%. Since the recall is proportional to the number of true positive classifications, which itself is a sum of indicator functions, this constraint -is non-differentible, and therefore cannot be used in a problem that will be +is non-differentiable, and therefore cannot be used in a problem that will be optimized using a (stochastic) gradient-based algorithm. For this and similar problems, TFCO supports so-called *proxy constraints*, diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py index 04014ab4aebd6d9cd70653c53f9361320e803329..3791dae8d7f6b03bc1115bca97811dfc4775c45b 100644 --- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -169,8 +169,8 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix): del old_inactive # Needed by the condition, but not the body. iteration += 1 scale = (1.0 - standard_ops.reduce_sum( - matrix, axis=0, keep_dims=True)) / standard_ops.maximum( - 1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True)) + matrix, axis=0, keepdims=True)) / standard_ops.maximum( + 1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True)) matrix += scale * inactive new_inactive = standard_ops.to_float(matrix > 0) matrix *= new_inactive @@ -206,10 +206,10 @@ def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix): # For numerical reasons, make sure that the largest matrix element is zero # before exponentiating. - log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True) + log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keepdims=True) log_matrix -= standard_ops.log( standard_ops.reduce_sum( - standard_ops.exp(log_matrix), axis=0, keep_dims=True)) + standard_ops.exp(log_matrix), axis=0, keepdims=True)) return log_matrix diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 1af1ed08b53ee04367eb316d5c9caa0216f2e88d..2a4cf877f0f1c5dcd738bb64c1c9a4c1f3c4560d 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -25,7 +25,10 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Counter @@CheckpointInputPipelineHook @@CsvDataset +@@RandomDataset +@@Reducer @@SqlDataset +@@TFRecordWriter @@assert_element_shape @@batch_and_drop_remainder @@ -33,11 +36,15 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@choose_from_datasets @@dense_to_sparse_batch @@enumerate_dataset + +@@get_single_element +@@group_by_reducer @@group_by_window @@ignore_errors @@make_batched_features_dataset @@make_csv_dataset @@make_saveable_from_iterator + @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave @@ -50,8 +57,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@sliding_window_batch @@sloppy_interleave @@unbatch - -@@get_single_element +@@unique """ from __future__ import absolute_import @@ -71,13 +77,17 @@ from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.get_single_element import get_single_element from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length +from tensorflow.contrib.data.python.ops.grouping import group_by_reducer from tensorflow.contrib.data.python.ops.grouping import group_by_window +from tensorflow.contrib.data.python.ops.grouping import Reducer +from tensorflow.contrib.data.python.ops.interleave_ops import choose_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.random_ops import RandomDataset from tensorflow.contrib.data.python.ops.readers import CsvDataset from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import make_csv_dataset @@ -87,6 +97,8 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch +from tensorflow.contrib.data.python.ops.unique import unique +from tensorflow.contrib.data.python.ops.writers import TFRecordWriter # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index e88ad3dc32003ece2b8810661cd4db374196561c..4657807785d58727d34f37172bd30c56a5b7cde6 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -236,7 +236,7 @@ class CSVDatasetOp : public DatasetOpKernel { size_t num_parsed = 0; size_t num_selected_parsed = 0; - Status result = Status::OK(); + Status result; while (!end_of_record) { // Read till we reach \n, \r or EOF bool include = @@ -329,6 +329,7 @@ class CSVDatasetOp : public DatasetOpKernel { size_t start = pos_; pos_++; // Starting quotation mark + Status parse_result; while (true) { // Each iter reads 1 char, filling buffer if necessary if (pos_ >= buffer_.size()) { Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); @@ -351,8 +352,9 @@ class CSVDatasetOp : public DatasetOpKernel { if (errors::IsOutOfRange(s)) { // This was the last field. We are done *end_of_record = true; - return QuotedFieldToOutput(ctx, StringPiece(), out_tensors, - earlier_pieces, include); + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(), out_tensors, earlier_pieces, include)); + return parse_result; } else if (!s.ok()) { return s; } @@ -361,20 +363,24 @@ class CSVDatasetOp : public DatasetOpKernel { char next = buffer_[pos_]; pos_++; if (next == dataset()->delim_) { - return QuotedFieldToOutput( + parse_result.Update(QuotedFieldToOutput( ctx, StringPiece(&buffer_[start], pos_ - 1 - start), - out_tensors, earlier_pieces, include); + out_tensors, earlier_pieces, include)); + return parse_result; } else if (next == '\n' || next == '\r') { *end_of_record = true; - Status s = QuotedFieldToOutput( + parse_result.Update(QuotedFieldToOutput( ctx, StringPiece(&buffer_[start], pos_ - 1 - start), - out_tensors, earlier_pieces, include); + out_tensors, earlier_pieces, include)); if (next == '\r') SkipNewLineIfNecessary(); - return s; + return parse_result; } else if (next != '"') { - return errors::InvalidArgument( - "Quote inside a string has to be escaped by another quote"); + // Take note of the error, but keep going to end of field. + include = false; // So we don't get funky errors when trying to + // unescape the quotes. + parse_result.Update(errors::InvalidArgument( + "Quote inside a string has to be escaped by another quote")); } } else { @@ -454,6 +460,8 @@ class CSVDatasetOp : public DatasetOpKernel { EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::vector earlier_pieces; size_t start = pos_; + Status parse_result; + while (true) { // Each iter reads 1 char, filling buffer if necessary if (pos_ >= buffer_.size()) { Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); @@ -461,9 +469,10 @@ class CSVDatasetOp : public DatasetOpKernel { if (errors::IsOutOfRange(s)) { // Whatever we have is the last field of the last record *end_of_record = true; - return UnquotedFieldToOutput( + parse_result.Update(UnquotedFieldToOutput( ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); + earlier_pieces, include)); + return parse_result; } else if (!s.ok()) { return s; // Surface all other errors to caller } @@ -472,66 +481,33 @@ class CSVDatasetOp : public DatasetOpKernel { char ch = buffer_[pos_]; if (ch == dataset()->delim_) { - Status s = UnquotedFieldToOutput( + parse_result.Update(UnquotedFieldToOutput( ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); + earlier_pieces, include)); pos_++; - return s; + return parse_result; } if (ch == '\n' || ch == '\r') { // need special case to skip over first \n of record if the line // breaks are \r\n - Status s = UnquotedFieldToOutput( + parse_result.Update(UnquotedFieldToOutput( ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); + earlier_pieces, include)); *end_of_record = true; pos_++; if (ch == '\r') SkipNewLineIfNecessary(); - return s; + return parse_result; } if (dataset()->use_quote_delim_ && ch == '"') { - // Advance pos_ to the next field anyway so that we can ignore - // errors gracefully if required. The caller of this will be able to - // call ParseOneField and continue with the rest of the record. - AdvanceToNextField(end_of_record); - return errors::InvalidArgument( - "Unquoted fields cannot have quotes inside"); + // Take note of the error, but keep going to end of field. + parse_result.Update(errors::InvalidArgument( + "Unquoted fields cannot have quotes inside")); } // Otherwise, go to next character pos_++; } } - // Advances pos_ to the start of the next field, as delimited by delim, - // CRLF, or EOF, ignoring errors, and not keeping track of characters in - // the current field. - void AdvanceToNextField(bool* end_of_record) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - while (true) { - if (pos_ >= buffer_.size()) { - Status s = FillBuffer(&buffer_); - pos_ = 0; - if (!s.ok()) { - *end_of_record = true; - return; - } - } - - char ch = buffer_[pos_]; - pos_++; - - if (ch == dataset()->delim_) { - return; - } - - if (ch == '\n' || ch == '\r') { - *end_of_record = true; - if (ch == '\r') SkipNewLineIfNecessary(); - return; - } - } - } - Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { result->clear(); Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 3dfc3741c2b040dd5be3223c24f0715ba3be4248..141706f393b076d9f55898ca4bdbe7438f7c3625 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { namespace { @@ -24,19 +25,32 @@ namespace { class ThreadPoolResource : public ResourceBase { public: ThreadPoolResource(Env* env, const ThreadOptions& thread_options, - const string& name, int num_threads, bool low_latency_hint) - : thread_pool_(env, thread_options, name, num_threads, low_latency_hint) { - } + const string& name, int num_threads, bool low_latency_hint, + int max_intra_op_parallelism) + : thread_pool_(env, thread_options, name, num_threads, low_latency_hint), + max_intra_op_parallelism_(max_intra_op_parallelism) {} // Schedules fn() for execution in the pool of threads. void Schedule(std::function fn) { - thread_pool_.Schedule(std::move(fn)); + if (max_intra_op_parallelism_ < 0) { + thread_pool_.Schedule(std::move(fn)); + } else { + thread_pool_.Schedule(std::bind( + [this](std::function bound_fn) { + // TODO(mrry): Consider moving this thread-local configuration to + // the threads themselves. + ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_); + bound_fn(); + }, + std::move(fn))); + } } string DebugString() override { return "ThreadPoolResource"; } private: thread::ThreadPool thread_pool_; + const int max_intra_op_parallelism_; }; // Creates a handle to a ThreadPool resource. Note that we don't use @@ -48,6 +62,8 @@ class ThreadPoolHandleOp : public OpKernel { explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism", + &max_intra_op_parallelism_)); OP_REQUIRES( ctx, num_threads_ > 0, errors::InvalidArgument("`num_threads` must be greater than zero.")); @@ -78,7 +94,7 @@ class ThreadPoolHandleOp : public OpKernel { EXCLUSIVE_LOCKS_REQUIRED(mu_) { *ret = new ThreadPoolResource( ctx->env(), {}, display_name_, - num_threads_, + num_threads_, max_intra_op_parallelism_, false /* low_latency_hint */); return Status::OK(); })); @@ -95,6 +111,7 @@ class ThreadPoolHandleOp : public OpKernel { bool initialized_ GUARDED_BY(mu_) = false; string display_name_; int num_threads_; + int max_intra_op_parallelism_; }; class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index f271d269ab1b9339de4657e459dcbbd462890f0a..f48e96509a193266d5d43453291c5e463f088117 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -158,6 +158,7 @@ REGISTER_OP("ThreadPoolHandle") .Output("handle: resource") .SetShapeFn(shape_inference::ScalarShape) .Attr("num_threads: int") + .Attr("max_intra_op_parallelism: int = 1") .Attr("display_name: string") .Attr("container: string = ''") .Attr("shared_name: string = ''") @@ -166,6 +167,8 @@ Creates a custom thread pool with the given number of threads. handle: A resource that can be consumed by one or more ThreadPoolDataset ops. num_threads: The number of threads in the thread pool. +max_intra_op_parallelism: The maximum degree of parallelism to use within + operations that execute on this threadpool. display_name: A human-readable name for the threads that may be visible in some visualizations. )doc"); diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index ba707d8d6e466561442f48e5dd7e8bdee20fb0f7..d81654e039c53e5b9434288352ef1b2416a4b7e8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test") py_test( name = "batch_dataset_op_test", @@ -16,20 +16,23 @@ py_test( "no_pip", ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -39,7 +42,6 @@ py_test( srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:grouping", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -48,24 +50,33 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "concatenate_dataset_op_test", + name = "csv_dataset_op_test", size = "small", - srcs = ["concatenate_dataset_op_test.py"], + srcs = ["csv_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], ) @@ -80,104 +91,44 @@ py_test( "nomac", # b/62040583 ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:session", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", - "//third_party/py/numpy", ], ) -py_library( - name = "dataset_serialization_test", - srcs = [ - "dataset_serialization_test_base.py", - ], +py_test( + name = "directed_interleave_dataset_test", + size = "medium", + srcs = ["directed_interleave_dataset_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:interleave_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "csv_dataset_op_test", - size = "small", - srcs = ["csv_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:error_ops", - "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:random_seed", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "filter_dataset_op_test", + name = "get_single_element_test", size = "small", - srcs = ["filter_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "optonly", - ], + srcs = ["get_single_element_test.py"], deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:functional_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "flat_map_dataset_op_test", - size = "medium", - srcs = ["flat_map_dataset_op_test.py"], - additional_deps = [ - ":dataset_serialization_test", - "//third_party/py/numpy", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:get_single_element", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:function", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", ], - grpc_enabled = True, - tags = ["no_pip"], ) py_test( @@ -192,10 +143,8 @@ py_test( "notap", ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:interleave_ops", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", @@ -203,43 +152,28 @@ py_test( "//tensorflow/python:script_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", + "@six_archive//:six", ], ) py_test( - name = "directed_interleave_dataset_test", - size = "medium", - srcs = ["directed_interleave_dataset_test.py"], + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:interleave_ops", - "//tensorflow/python:client", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -tf_py_test( - name = "get_single_element_test", - size = "small", - srcs = ["get_single_element_test.py"], - additional_deps = [ - "//third_party/py/numpy", - "//tensorflow/contrib/data/python/ops:get_single_element", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python:array_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", ], ) @@ -254,27 +188,13 @@ py_test( "optonly", ], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:error_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", "//tensorflow/python:io_ops", - "//tensorflow/python:lookup_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", "//tensorflow/python:util", - "//tensorflow/python:variable_scope", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -286,23 +206,30 @@ py_test( srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/python:platform", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", ], ) -py_test( - name = "prefetch_dataset_op_test", +cuda_py_test( + name = "prefetching_ops_test", size = "small", - srcs = ["prefetch_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":dataset_serialization_test", - "//tensorflow/python:platform", + srcs = ["prefetching_ops_test.py"], + additional_deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", ], ) @@ -312,46 +239,60 @@ py_test( srcs = ["range_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:counter", "//tensorflow/contrib/data/python/ops:enumerate_ops", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", - "//tensorflow/python:parsing_ops", "//tensorflow/python:tensor_shape", - "//tensorflow/python:variables", "//tensorflow/python/data/ops:dataset_ops", ], ) +py_library( + name = "reader_dataset_ops_test_base", + testonly = 1, + srcs = [ + "reader_dataset_ops_test_base.py", + ], + srcs_version = "PY2AND3", + visibility = [ + "//tensorflow/contrib/data/python/kernel_tests:__pkg__", + "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", + ], + deps = [ + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:lib", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", + ], +) + py_test( name = "reader_dataset_ops_test", size = "medium", srcs = ["reader_dataset_ops_test.py"], - shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:string_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], ) @@ -378,6 +319,7 @@ py_test( "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", + "@six_archive//:six", ], ) @@ -388,13 +330,14 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:scan_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", "//third_party/py/numpy", @@ -402,55 +345,55 @@ py_test( ) py_test( - name = "sequence_dataset_op_test", + name = "shuffle_dataset_op_test", size = "medium", - srcs = ["sequence_dataset_op_test.py"], + srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "optonly", + ], deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) py_test( - name = "serialization_integration_test", + name = "slide_dataset_op_test", size = "small", - srcs = ["serialization_integration_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], + srcs = ["slide_dataset_op_test.py"], deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:sliding", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) -py_test( - name = "shuffle_dataset_op_test", - size = "medium", - srcs = ["shuffle_dataset_op_test.py"], +py_library( + name = "sql_dataset_op_test_base", + srcs = ["sql_dataset_op_test_base.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + visibility = [ + "//tensorflow/contrib/data/python/kernel_tests:__pkg__", + "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__", + ], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:shuffle_ops", + "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", + "@org_sqlite//:python", ], ) @@ -459,14 +402,12 @@ py_test( size = "small", srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/python:array_ops", + ":sql_dataset_op_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "@org_sqlite//:python", ], ) @@ -477,11 +418,15 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) @@ -495,8 +440,12 @@ py_test( "//tensorflow/contrib/data/python/ops:threadpool", "//tensorflow/contrib/data/python/ops:unique", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:script_ops", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -507,87 +456,27 @@ py_test( srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ - ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:unique", - "//tensorflow/contrib/stateless", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", ], ) py_test( - name = "zip_dataset_op_test", - size = "small", - srcs = ["zip_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":dataset_serialization_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python/data/ops:dataset_ops", - "//third_party/py/numpy", - ], -) - -cuda_py_test( - name = "prefetching_ops_test", - size = "small", - srcs = ["prefetching_ops_test.py"], - additional_deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:iterator_ops", - ], -) - -tf_py_test( - name = "slide_dataset_op_test", - size = "small", - srcs = ["slide_dataset_op_test.py"], - additional_deps = [ - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:sliding", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", - "//third_party/py/numpy", - ], -) - -tf_py_test( name = "writer_ops_test", size = "small", srcs = ["writer_ops_test.py"], - additional_deps = [ + deps = [ "//tensorflow/contrib/data/python/ops:writers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", "//tensorflow/python:lib", - "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index b5fbc45ad3d8d262c1c79b5723ffeb38ff6a34c2..4c6023230846338e7e7662cd908c2ab9fd2e5483 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -20,9 +20,9 @@ from __future__ import print_function import math import time +from absl.testing import parameterized import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops @@ -40,7 +40,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class BatchDatasetTest(test.TestCase): +class BatchDatasetTest(test.TestCase, parameterized.TestCase): def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) @@ -427,9 +427,13 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, - num_parallel_calls=None, - num_parallel_batches=None): + @parameterized.named_parameters( + ("default", None, None), + ("sequential_calls", 1, None), + ("parallel_calls", 2, None), + ("parallel_batches", None, 10), + ) + def testMapAndBatch(self, num_parallel_calls, num_parallel_batches): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). @@ -500,19 +504,11 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatch(self): - return self._testMapAndBatchDatasetHelper() - - def testMapAndBatchWithParallelBatches(self): - return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) - - def testMapAndBatchWithSequentialCalls(self): - return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) - - def testMapAndBatchWithParallelCalls(self): - return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) - - def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): + @parameterized.named_parameters( + ("even", False), + ("uneven", True), + ) + def testMapAndBatchPartialBatch(self, drop_remainder): iterator = ( dataset_ops.Dataset.range(10).apply( batching.map_and_batch( @@ -532,12 +528,6 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) - def testMapAndBatchPartialBatch(self): - return self._testMapAndBatchPartialBatchHelper() - - def testMapAndBatchPartialBatchDropRemainder(self): - return self._testMapAndBatchPartialBatchHelper(drop_remainder=True) - def testMapAndBatchYieldsPartialBatch(self): iterator = (dataset_ops.Dataset.range(10) .apply(batching.map_and_batch( @@ -614,7 +604,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testMapAndBatchDatasetFails(self): + def testMapAndBatchFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( @@ -628,7 +618,7 @@ class BatchDatasetTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(init_op, feed_dict={batch_size: 14}) - def testMapAndBatchDatasetShapeMismatch(self): + def testMapAndBatchShapeMismatch(self): """Test a dataset that maps a TF function across its input elements.""" def generator(): @@ -652,174 +642,6 @@ class BatchDatasetTest(test.TestCase): sess.run(get_next) -class BatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): - components = ( - np.arange(tensor_slice_len), - np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(tensor_slice_len)) - - return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) - - def testCore(self): - tensor_slice_len = 8 - batch_size = 2 - num_outputs = tensor_slice_len // batch_size - self.run_core_tests( - lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), - lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), - num_outputs) - - def _build_dataset_dense_to_sparse(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, [12])) - - def testDenseToSparseBatchDatasetCore(self): - components = np.random.randint(5, size=(40,)).astype(np.int32) - diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) - - num_outputs = len(components) // 4 - self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), - lambda: self._build_dataset_dense_to_sparse(diff_comp), - num_outputs) - - def _sparse(self, i): - return sparse_tensor.SparseTensorValue( - indices=[[0]], values=(i * [1]), dense_shape=[1]) - - def _build_dataset_sparse(self, batch_size=5): - return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) - - def testSparseCore(self): - self.run_core_tests(self._build_dataset_sparse, - lambda: self._build_dataset_sparse(2), 2) - - def _build_dataset_nested_sparse(self): - return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) - - def testNestedSparseCore(self): - self.run_core_tests(self._build_dataset_nested_sparse, None, 1) - - -class UnbatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): - components = ( - np.arange(tensor_slice_len), - np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(tensor_slice_len)) - - return dataset_ops.Dataset.from_tensor_slices(components).batch( - batch_size).apply(batching.unbatch()) - - def testCore(self): - tensor_slice_len = 8 - batch_size = 2 - num_outputs = tensor_slice_len - self.run_core_tests( - lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), - lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), - num_outputs) - - -class MapAndBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testNumParallelBatches(self): - range_size = 11 - num_repeats = 2 - batch_size = 5 - total_outputs = range_size * num_repeats - num_outputs_drop_remainder = total_outputs // batch_size - num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) - num_parallel_batches = 2 - - def build_ds(range_start, drop_remainder=False): - - def _map_fn(x): - return math_ops.square(x) - - return dataset_ops.Dataset.range( - range_start, range_start + range_size).repeat(num_repeats).apply( - batching.map_and_batch( - map_func=_map_fn, - batch_size=batch_size, - num_parallel_batches=num_parallel_batches, - drop_remainder=drop_remainder)) - - self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), - num_outputs_keep_remainder) - self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), - num_outputs_drop_remainder) - - def testNumParallelCalls(self): - range_size = 11 - num_repeats = 2 - batch_size = 5 - total_outputs = range_size * num_repeats - num_outputs_drop_remainder = total_outputs // batch_size - num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) - num_parallel_calls = 7 - - def build_ds(range_start, drop_remainder=False): - - def _map_fn(x): - return math_ops.square(x) - - return dataset_ops.Dataset.range( - range_start, range_start + range_size).repeat(num_repeats).apply( - batching.map_and_batch( - map_func=_map_fn, - batch_size=batch_size, - num_parallel_calls=num_parallel_calls, - drop_remainder=drop_remainder)) - - self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), - num_outputs_keep_remainder) - self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), - num_outputs_drop_remainder) - - -class PaddedBatchDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testPaddedBatch(self): - - def build_dataset(seq_lens): - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - lambda x: array_ops.fill([x], x)).padded_batch( - 4, padded_shapes=[-1]) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - def testPaddedBatchNonDefaultPadding(self): - - def build_dataset(seq_lens): - - def fill_tuple(x): - filled = array_ops.fill([x], x) - return (filled, string_ops.as_string(filled)) - - padded_shape = [-1] - return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( - fill_tuple).padded_batch( - 4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, "")) - - seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) - seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) - self.run_core_tests(lambda: build_dataset(seq_lens1), - lambda: build_dataset(seq_lens2), 8) - - class RestructuredDatasetTest(test.TestCase): def test_assert_element_shape(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index bd3e034211c4aa454e4f8f6b09f14935d7a3b35c..c5d2edbbc6240675dd7add4744de0966cf1dfe13 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -21,7 +21,6 @@ import random import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -68,7 +67,7 @@ class GroupByReducerTest(test.TestCase): reducer = grouping.Reducer( init_func=lambda _: (0.0, 0.0), reduce_func=reduce_fn, - finalize_func=lambda x: x[0]) + finalize_func=lambda x, _: x) for i in range(1, 11): dataset = dataset_ops.Dataset.range(2 * i).apply( grouping.group_by_reducer( @@ -121,7 +120,7 @@ class GroupByReducerTest(test.TestCase): reducer = grouping.Reducer( init_func=lambda x: ([0], 1), reduce_func=reduce_fn, - finalize_func=lambda x: x) + finalize_func=lambda x, y: (x, y)) for i in range(1, 11): dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( @@ -177,38 +176,6 @@ class GroupByReducerTest(test.TestCase): grouping.group_by_reducer(lambda _: "wrong", reducer)) -class GroupByReducerSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, components): - reducer = grouping.Reducer( - init_func=lambda _: np.int64(0), - reduce_func=lambda x, y: x + y, - finalize_func=lambda x: x) - - return dataset_ops.Dataset.from_tensor_slices(components).apply( - grouping.group_by_reducer(lambda x: x % 5, reducer)) - - def testCoreGroupByReducer(self): - components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) - self.verify_unused_iterator( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_init_before_restore( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_multiple_breaks( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_reset_restored_iterator( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - self.verify_restore_in_empty_graph( - lambda: self._build_dataset(components), 5, verify_exhausted=True) - diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) - self.verify_restore_in_modified_graph( - lambda: self._build_dataset(components), - lambda: self._build_dataset(diff_components), - 5, - verify_exhausted=True) - - class GroupByWindowTest(test.TestCase): def testSimple(self): @@ -353,34 +320,6 @@ class GroupByWindowTest(test.TestCase): self.assertEqual(len(components), sum(counts)) -class GroupByWindowSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( - grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) - - def testCoreGroupByWindow(self): - components = np.array( - [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) - self.verify_unused_iterator( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_init_before_restore( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_multiple_breaks( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_reset_restored_iterator( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - self.verify_restore_in_empty_graph( - lambda: self._build_dataset(components), 12, verify_exhausted=False) - diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) - self.verify_restore_in_modified_graph( - lambda: self._build_dataset(components), - lambda: self._build_dataset(diff_components), - 12, - verify_exhausted=False) - - # NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. # Currently, they use a constant batch size, though should be made to use a # different batch size per key. diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 74b90ec7d1617d221888d1e1c56cf594c367ddf9..df115175f5046803ada036563be1ca802f7ad0cd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -33,7 +33,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest from tensorflow.python.platform import test @@ -76,7 +76,7 @@ class CsvDatasetOpTest(test.TestCase): filenames = self.setup_files(inputs) dataset_expected = core_readers.TextLineDataset(filenames) dataset_expected = dataset_expected.map( - lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) + lambda l: parsing_ops.decode_csv(l, **kwargs)) dataset_actual = readers.CsvDataset(filenames, **kwargs) return (dataset_actual, dataset_expected) @@ -162,9 +162,28 @@ class CsvDatasetOpTest(test.TestCase): expected_err_re='Unquoted fields cannot have quotes inside', record_defaults=record_defaults) + def testCsvDataset_errWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['"a"b","c","d"']] + self._test_dataset( + inputs, + expected_err_re= + 'Quote inside a string has to be escaped by another quote', + record_defaults=record_defaults) + + def testCsvDataset_ignoreErrWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + def testCsvDataset_ignoreErrWithUnquotedQuotes(self): record_defaults = [['']] * 3 - inputs = [['1,2"3,4', 'a,b,c"d', 'e,f,g']] + inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] filenames = self.setup_files(inputs) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: @@ -562,7 +581,7 @@ class CsvDatasetBenchmark(test.Benchmark): num_cols = self._num_cols[i] kwargs = {'record_defaults': [[0.0]] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() - dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv') self._tearDown() @@ -572,7 +591,7 @@ class CsvDatasetBenchmark(test.Benchmark): num_cols = self._num_cols[i] kwargs = {'record_defaults': [['']] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() - dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv') self._tearDown() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index a842502cc6fe3605dde0be5f50cf46e3e37d7ed4..a2ab3de52e8e512e3cba399f7a1725e5570cfd01 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,14 +17,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes -from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -70,63 +66,5 @@ class DatasetConstructorTest(test.TestCase): # pylint: enable=protected-access -class DatasetConstructorSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_tensor_dataset(self, variable_array): - components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) - - return dataset_ops.Dataset.from_tensors(components) - - def testFromTensorsCore(self): - # Equal length components - arr = np.array(1) - num_outputs = 1 - diff_arr = np.array(2) - self.run_core_tests(lambda: self._build_tensor_dataset(arr), - lambda: self._build_tensor_dataset(diff_arr), - num_outputs) - - def _build_tensor_slices_dataset(self, components): - return dataset_ops.Dataset.from_tensor_slices(components) - - def testFromTensorSlicesCore(self): - # Equal length components - components = (np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 22), - np.array([37.0, 38.0, 39.0, 40.0])) - - diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[5], [6], [7], [8]]), 22), - np.array([1.0, 2.0, 3.0, 4.0])) - - dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} - - self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), - lambda: self._build_tensor_slices_dataset(diff_comp), 4) - self.run_core_tests( - lambda: self._build_tensor_slices_dataset(dict_components), None, 3) - - def _build_sparse_tensor_slice_dataset(self, slices): - indices = np.array( - [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], - dtype=np.int64) - values = np.array([val for s in slices for val in s], dtype=np.float64) - dense_shape = np.array( - [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) - sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) - return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) - - def testFromSparseTensorSlicesCore(self): - slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] - diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] - - self.run_core_tests( - lambda: self._build_sparse_tensor_slice_dataset(slices), - lambda: self._build_sparse_tensor_slice_dataset(diff_slices), - 9, - sparse_tensors=True) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index 34b6a080c0aae7dfc228746139acc52cea4e6f28..9b1857de1a96c8f71788a1bf5085ef0605417fe7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -19,7 +19,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors @@ -34,8 +33,8 @@ class DirectedInterleaveDatasetTest(test.TestCase): input_datasets = [ dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) ] - dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, - input_datasets) + dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset, + input_datasets) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() @@ -144,24 +143,5 @@ class DirectedInterleaveDatasetTest(test.TestCase): ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0])) -class SampleFromDatasetsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, probs, num_samples): - dataset = interleave_ops.sample_from_datasets( - [ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(len(probs)) - ], - probs, - seed=1813) - return dataset.take(num_samples) - - def testSerializationCore(self): - self.run_core_tests( - lambda: self._build_dataset([0.5, 0.5], 100), - lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index bee561e3e23a2ab6f314894caa21785347e6ca8b..44c3325a3db84bb844b7f860a7c925982f1e3d6a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -22,10 +22,8 @@ import math import threading import time -import numpy as np from six.moves import zip_longest -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes @@ -38,132 +36,6 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, input_values, cycle_length, block_length): - repeat_count = 2 - return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( - repeat_count).interleave( - lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length) - - def testSerializationCore(self): - input_values = np.array([4, 5, 6], dtype=np.int64) - num_outputs = np.sum(input_values) * 2 - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 - # pylint: disable=g-long-lambda - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - lambda: self._build_iterator_graph( - input_values, cycle_length * 2, block_length * 1), - num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests( - lambda: self._build_iterator_graph( - input_values, cycle_length, block_length), - None, num_outputs) - # pylint: enable=g-long-lambda - - def testSparseCore(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _interleave_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - def _build_dataset(): - return dataset_ops.Dataset.range(10).map(_map_fn).interleave( - _interleave_fn, cycle_length=1) - - self.run_core_tests(_build_dataset, None, 20) - - -class ParallelInterleaveDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self.input_values = np.array([4, 5, 6], dtype=np.int64) - self.num_repeats = 2 - self.num_outputs = np.sum(self.input_values) * 2 - - def _build_ds(self, cycle_length, block_length, sloppy=False): - return (dataset_ops.Dataset.from_tensor_slices( - self.input_values).repeat(self.num_repeats).apply( - interleave_ops.parallel_interleave( - lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), - cycle_length, block_length, sloppy))) - - def testSerializationCore(self): - # cycle_length > 1, block_length > 1 - cycle_length = 2 - block_length = 3 - self.run_core_tests( - lambda: self._build_ds(cycle_length, block_length), - lambda: self._build_ds(cycle_length * 2, block_length * 1), - self.num_outputs) - # cycle_length = 1 - cycle_length = 1 - block_length = 3 - self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), - None, self.num_outputs) - # block_length = 1 - cycle_length = 2 - block_length = 1 - self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), - None, self.num_outputs) - - def testSerializationWithSloppy(self): - break_points = self.gen_break_points(self.num_outputs, 10) - expected_outputs = np.repeat( - np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), - self.num_repeats).tolist() - - def run_test(cycle_length, block_length): - actual = self.gen_outputs( - lambda: self._build_ds(cycle_length, block_length, True), - break_points, self.num_outputs) - self.assertSequenceEqual(sorted(actual), expected_outputs) - - # cycle_length > 1, block_length > 1 - run_test(2, 3) - # cycle_length = 1 - run_test(1, 3) - # block_length = 1 - run_test(2, 1) - - def testSparseCore(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _interleave_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - def _build_dataset(): - return dataset_ops.Dataset.range(10).map(_map_fn).apply( - interleave_ops.parallel_interleave(_interleave_fn, 1)) - - self.run_core_tests(_build_dataset, None, 20) - - class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py similarity index 100% rename from tensorflow/contrib/data/python/ops/iterator_ops_test.py rename to tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 8d4042927970cab2f5a518fc0da49b38444dbcdf..270a2297b4d7b4fc44e3d1fa0aea8c9dfa5f39d3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -21,20 +21,12 @@ import os import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import error_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import io_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -143,229 +135,5 @@ class MapDatasetTest(test.TestCase): sess.run(get_next) -class MapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self._tensor_slice_len = 7 - self._num_epochs = 14 - self._num_outputs = self._tensor_slice_len * self._num_epochs - - def _build_ds(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return ( - dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(self._num_epochs)) - - def testSaveRestoreCore(self): - self.run_core_tests( - self._build_ds, - lambda: self._build_ds(multiplier=15.0), - self._num_outputs) - - def testSaveStatefulFunction(self): - - def _build_ds(): - - def _map_fn(x): - return random_ops.random_uniform( - (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - - return dataset_ops.Dataset.range(100).map(_map_fn) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureVariableInMapFn(self): - - def _build_ds(): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1))) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureConstantInMapFn(self): - - def _build_ds(): - constant_var = constant_op.constant(5) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda x: x + constant_var)) - - self.run_core_tests(_build_ds, None, 10) - - def testCaptureDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testBuildDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - - @function.Defun(dtypes.int32) - def defun_fn_deep(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testSparseCore(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1])), - dense_shape=np.array([1, 1])) - - def _build_ds(num_outputs): - return dataset_ops.Dataset.range(num_outputs).map(_sparse) - - num_outputs = 10 - self.run_core_tests(lambda: _build_ds(num_outputs), - lambda: _build_ds(int(num_outputs / 2)), num_outputs) - - -class ParallelMapDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def setUp(self): - self._tensor_slice_len = 7 - self._num_epochs = 1 - self._num_outputs = self._tensor_slice_len * self._num_epochs - - def _build_ds(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return (dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) - - def _build_ds_with_prefetch(self, multiplier=37.0): - components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * - np.arange(self._tensor_slice_len)[:, np.newaxis], - np.array(multiplier) * np.arange(self._tensor_slice_len)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - return (dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) - - def testSaveRestoreCore(self): - for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: - self.run_core_tests( - ds_fn, - lambda: ds_fn(multiplier=15.0), - self._num_outputs) - - def testSaveStatefulFunction(self): - - def _build_ds(): - - def _map_fn(x): - return random_ops.random_uniform( - (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - - return dataset_ops.Dataset.range(100).map( - _map_fn, num_parallel_calls=2).prefetch(2) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureVariableInMapFn(self): - - def _build_ds(): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1), - num_parallel_calls=2).prefetch(2)) - - self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - - def testCaptureConstantInMapFn(self): - - def _build_ds(): - constant_var = constant_op.constant(5) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) - - self.run_core_tests(_build_ds, None, 10) - - def testCaptureDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return dataset_ops.Dataset.range(num_outputs).map( - defun_fn, num_parallel_calls=2).prefetch(2) - - self.run_core_tests(_build_ds, None, num_outputs) - - def testBuildDefunInMapFn(self): - num_outputs = 100 - - def _build_ds(): - - @function.Defun(dtypes.int64) - def defun_fn(x): - - @function.Defun(dtypes.int32) - def defun_fn_deep(x): - return constant_op.constant(1000) + math_ops.to_int32(x) - - return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - - return dataset_ops.Dataset.range(num_outputs).map( - defun_fn, num_parallel_calls=2).prefetch(2) - - self.run_core_tests(_build_ds, None, num_outputs) - - -class IgnoreErrorsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_ds(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.check_numerics(x, "message")).apply( - error_ops.ignore_errors()) - - def testIgnoreErrorsCore(self): - components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) - num_outputs = 4 - self.run_core_tests(lambda: self._build_ds(components), - lambda: self._build_ds(diff_components), num_outputs) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index 30f1847dcddbfaf379ef2b09185f7a8db4aaeae2..e35be8a23f3706bd170c09b967b4f419fc9a626e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import optimization from tensorflow.core.framework import graph_pb2 from tensorflow.python.data.ops import dataset_ops @@ -73,17 +72,5 @@ class OptimizeDatasetTest(test.TestCase): sess.run(get_next) -class OptimizeDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testCore(self): - - def build_dataset(num_elements, batch_size): - return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( - batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) - - self.run_core_tests(lambda: build_dataset(200, 10), None, 20) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 80e1cb0041024b68bd5268b5de5d69c88c839896..592642da0cfd84e50cb20d9b2e534411faf927e8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -17,21 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import counter from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -81,88 +73,5 @@ class RangeDatasetTest(test.TestCase): self.assertEqual(-2, sess.run(negative_get_next)) -class RangeDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _iterator_checkpoint_prefix_local(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(self, iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - self._iterator_checkpoint_prefix_local(), - parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(self, iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(self._iterator_checkpoint_prefix_local()), - dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def testSaveRestore(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. - start = 2 - stop = 10 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Saving and restoring in same session. - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def _build_range_dataset(self, start, stop): - return dataset_ops.Dataset.range(start, stop) - - def testRangeCore(self): - start = 2 - stop = 10 - stop_1 = 8 - self.run_core_tests(lambda: self._build_range_dataset(start, stop), - lambda: self._build_range_dataset(start, stop_1), - stop - start) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index e0237198b7d47eb98eeffe88d28bf9681b2722c6..9df403ef50e459d94b8edf3f651c7c95baf3ec42 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -17,426 +17,24 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gzip import os -import zlib import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers -from tensorflow.core.example import example_pb2 -from tensorflow.core.example import feature_pb2 -from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.lib.io import python_io -from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -from tensorflow.python.util import compat -class TextLineDatasetTestBase(test.TestCase): - - def _lineText(self, f, l): - return compat.as_bytes("%d: %d" % (f, l)) - - def _createFiles(self, - num_files, - num_lines, - crlf=False, - compression_type=None): - filenames = [] - for i in range(num_files): - fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) - filenames.append(fn) - contents = [] - for j in range(num_lines): - contents.append(self._lineText(i, j)) - # Always include a newline after the record unless it is - # at the end of the file, in which case we include it - if j + 1 != num_lines or i == 0: - contents.append(b"\r\n" if crlf else b"\n") - contents = b"".join(contents) - - if not compression_type: - with open(fn, "wb") as f: - f.write(contents) - elif compression_type == "GZIP": - with gzip.GzipFile(fn, "wb") as f: - f.write(contents) - elif compression_type == "ZLIB": - contents = zlib.compress(contents) - with open(fn, "wb") as f: - f.write(contents) - else: - raise ValueError("Unsupported compression_type", compression_type) - - return filenames - - -class TextLineDatasetSerializationTest( - TextLineDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, test_filenames, compression_type=None): - return core_readers.TextLineDataset( - test_filenames, compression_type=compression_type, buffer_size=10) - - def testTextLineCore(self): - compression_types = [None, "GZIP", "ZLIB"] - num_files = 5 - lines_per_file = 5 - num_outputs = num_files * lines_per_file - for compression_type in compression_types: - test_filenames = self._createFiles( - num_files, - lines_per_file, - crlf=True, - compression_type=compression_type) - # pylint: disable=cell-var-from-loop - self.run_core_tests( - lambda: self._build_iterator_graph(test_filenames, compression_type), - lambda: self._build_iterator_graph(test_filenames), num_outputs) - # pylint: enable=cell-var-from-loop - - -class FixedLengthRecordReaderTestBase(test.TestCase): - - def setUp(self): - super(FixedLengthRecordReaderTestBase, self).setUp() - self._num_files = 2 - self._num_records = 7 - self._header_bytes = 5 - self._record_bytes = 3 - self._footer_bytes = 2 - - def _record(self, f, r): - return compat.as_bytes(str(f * 2 + r) * self._record_bytes) - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) - filenames.append(fn) - with open(fn, "wb") as f: - f.write(b"H" * self._header_bytes) - for j in range(self._num_records): - f.write(self._record(i, j)) - f.write(b"F" * self._footer_bytes) - return filenames - - -class FixedLengthRecordDatasetSerializationTest( - FixedLengthRecordReaderTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, num_epochs, compression_type=None): - filenames = self._createFiles() - return core_readers.FixedLengthRecordDataset( - filenames, self._record_bytes, self._header_bytes, - self._footer_bytes).repeat(num_epochs) - - def testFixedLengthRecordCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), - lambda: self._build_iterator_graph(num_epochs * 2), - num_outputs) - - -class TFRecordDatasetTestBase(test.TestCase): - - def setUp(self): - super(TFRecordDatasetTestBase, self).setUp() - self._num_files = 2 - self._num_records = 7 - - self.test_filenames = self._createFiles() - - self.filenames = array_ops.placeholder(dtypes.string, shape=[None]) - self.num_epochs = array_ops.placeholder_with_default( - constant_op.constant(1, dtypes.int64), shape=[]) - self.compression_type = array_ops.placeholder_with_default("", shape=[]) - self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = core_readers.TFRecordDataset( - self.filenames, self.compression_type).repeat(self.num_epochs) - batch_dataset = repeat_dataset.batch(self.batch_size) - - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) - self.init_op = iterator.make_initializer(repeat_dataset) - self.init_batch_op = iterator.make_initializer(batch_dataset) - self.get_next = iterator.get_next() - - def _record(self, f, r): - return compat.as_bytes("Record %d of file %d" % (r, f)) - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = python_io.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._record(i, j)) - writer.close() - return filenames - - -class TFRecordDatasetSerializationTest( - TFRecordDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_iterator_graph(self, - num_epochs, - batch_size=1, - compression_type=None, - buffer_size=None): - filenames = self._createFiles() - if compression_type is "ZLIB": - zlib_files = [] - for i, fn in enumerate(filenames): - with open(fn, "rb") as f: - cdata = zlib.compress(f.read()) - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) - with open(zfn, "wb") as f: - f.write(cdata) - zlib_files.append(zfn) - filenames = zlib_files - - elif compression_type is "GZIP": - gzip_files = [] - for i, fn in enumerate(self.test_filenames): - with open(fn, "rb") as f: - gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) - with gzip.GzipFile(gzfn, "wb") as gzf: - gzf.write(f.read()) - gzip_files.append(gzfn) - filenames = gzip_files - - return core_readers.TFRecordDataset( - filenames, compression_type, - buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) - - def testTFRecordWithoutBufferCore(self): - num_epochs = 5 - batch_size = num_epochs - num_outputs = num_epochs * self._num_files * self._num_records // batch_size - # pylint: disable=g-long-lambda - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, batch_size, - buffer_size=0), - lambda: self._build_iterator_graph(num_epochs * 2, batch_size), - num_outputs) - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, - num_outputs * batch_size) - # pylint: enable=g-long-lambda - - def testTFRecordWithBufferCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), - lambda: self._build_iterator_graph(num_epochs * 2), - num_outputs) - - def testTFRecordWithCompressionCore(self): - num_epochs = 5 - num_outputs = num_epochs * self._num_files * self._num_records - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), - lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) - self.run_core_tests( - lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), - lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) - - -def _interleave(iterators, cycle_length): - pending_iterators = iterators - open_iterators = [] - num_open = 0 - for i in range(cycle_length): - if pending_iterators: - open_iterators.append(pending_iterators.pop(0)) - num_open += 1 - - while num_open: - for i in range(min(cycle_length, len(open_iterators))): - if open_iterators[i] is None: - continue - try: - yield next(open_iterators[i]) - except StopIteration: - if pending_iterators: - open_iterators[i] = pending_iterators.pop(0) - else: - open_iterators[i] = None - num_open -= 1 - - -class ReadBatchFeaturesTest(test.TestCase): - - def setUp(self): - super(ReadBatchFeaturesTest, self).setUp() - self._num_files = 2 - self._num_records = 7 - self.test_filenames = self._createFiles() - - def _read_batch_features(self, - filenames, - num_epochs, - batch_size, - reader_num_threads=1, - parser_num_threads=1, - shuffle=False, - shuffle_seed=None, - drop_final_batch=False): - self.filenames = filenames - self.num_epochs = num_epochs - self.batch_size = batch_size - - return readers.make_batched_features_dataset( - file_pattern=self.filenames, - batch_size=self.batch_size, - features={ - "file": parsing_ops.FixedLenFeature([], dtypes.int64), - "record": parsing_ops.FixedLenFeature([], dtypes.int64), - "keywords": parsing_ops.VarLenFeature(dtypes.string) - }, - reader=core_readers.TFRecordDataset, - num_epochs=self.num_epochs, - shuffle=shuffle, - shuffle_seed=shuffle_seed, - reader_num_threads=reader_num_threads, - parser_num_threads=parser_num_threads, - drop_final_batch=drop_final_batch).make_one_shot_iterator( - ).get_next() - - def _record(self, f, r): - example = example_pb2.Example( - features=feature_pb2.Features( - feature={ - "file": - feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[f])), - "record": - feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[r])), - "keywords": - feature_pb2.Feature( - bytes_list=feature_pb2.BytesList( - value=self._get_keywords(f, r))) - })) - return example.SerializeToString() - - def _get_keywords(self, f, r): - num_keywords = 1 + (f + r) % 2 - keywords = [] - for index in range(num_keywords): - keywords.append(compat.as_bytes("keyword%d" % index)) - return keywords - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = python_io.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._record(i, j)) - writer.close() - return filenames - - def _run_actual_batch(self, outputs, sess): - file_op = outputs["file"] - keywords_indices_op = outputs["keywords"].indices - keywords_values_op = outputs["keywords"].values - keywords_dense_shape_op = outputs["keywords"].dense_shape - record_op = outputs["record"] - return sess.run([ - file_op, keywords_indices_op, keywords_values_op, - keywords_dense_shape_op, record_op - ]) - - def _next_actual_batch(self, sess): - return self._run_actual_batch(self.outputs, sess) - - def _next_expected_batch(self, - file_indices, - batch_size, - num_epochs, - cycle_length=1): - - def _next_record(file_indices): - for j in file_indices: - for i in range(self._num_records): - yield j, i - - def _next_record_interleaved(file_indices, cycle_length): - return _interleave([_next_record([i]) for i in file_indices], - cycle_length) - - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - for _ in range(num_epochs): - if cycle_length == 1: - next_records = _next_record(file_indices) - else: - next_records = _next_record_interleaved(file_indices, cycle_length) - for record in next_records: - f = record[0] - r = record[1] - file_batch.append(f) - record_batch.append(r) - keywords = self._get_keywords(f, r) - keywords_batch_values.extend(keywords) - keywords_batch_indices.extend( - [[batch_index, i] for i in range(len(keywords))]) - batch_index += 1 - keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) - if len(file_batch) == batch_size: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [batch_size, keywords_batch_max_len], record_batch - ] - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - if file_batch: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [len(file_batch), keywords_batch_max_len], record_batch - ] - - def _verify_records(self, - sess, - batch_size, - file_index=None, - num_epochs=1, - interleave_cycle_length=1): - if file_index is not None: - file_indices = [file_index] - else: - file_indices = range(self._num_files) - - for expected_batch in self._next_expected_batch( - file_indices, batch_size, num_epochs, interleave_cycle_length): - actual_batch = self._next_actual_batch(sess) - for i in range(len(expected_batch)): - self.assertAllEqual(expected_batch[i], actual_batch[i]) +class ReadBatchFeaturesTest( + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testRead(self): for batch_size in [1, 2]: @@ -444,33 +42,33 @@ class ReadBatchFeaturesTest(test.TestCase): with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 0, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 0, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 1. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[1], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 1, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 1, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from both files. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) @@ -504,18 +102,18 @@ class ReadBatchFeaturesTest(test.TestCase): # Test that shuffling with same seed produces the same result. with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - outputs1 = self._read_batch_features( + outputs1 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) - outputs2 = self._read_batch_features( + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) + shuffle_seed=5).make_one_shot_iterator().get_next() for _ in range(total_records // batch_size): batch1 = self._run_actual_batch(outputs1, sess) batch2 = self._run_actual_batch(outputs2, sess) @@ -525,18 +123,18 @@ class ReadBatchFeaturesTest(test.TestCase): # Test that shuffling with different seeds produces a different order. with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - outputs1 = self._read_batch_features( + outputs1 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) - outputs2 = self._read_batch_features( + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=15) + shuffle_seed=15).make_one_shot_iterator().get_next() all_equal = True for _ in range(total_records // batch_size): batch1 = self._run_actual_batch(outputs1, sess) @@ -552,13 +150,14 @@ class ReadBatchFeaturesTest(test.TestCase): for parser_num_threads in [2, 4]: with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, batch_size=batch_size, reader_num_threads=reader_num_threads, - parser_num_threads=parser_num_threads) - self._verify_records( + parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() + self.verify_records( sess, batch_size, num_epochs=num_epochs, @@ -571,11 +170,11 @@ class ReadBatchFeaturesTest(test.TestCase): for num_epochs in [1, 10]: with ops.Graph().as_default(): # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, - drop_final_batch=True) + drop_final_batch=True).make_one_shot_iterator().get_next() for _, tensor in self.outputs.items(): if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) @@ -1069,7 +668,30 @@ class MakeCsvDatasetTest(test.TestCase): self.assertFalse(all_equal) -class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): +class MakeTFRecordDatasetTest( + reader_dataset_ops_test_base.TFRecordDatasetTestBase): + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 def _next_expected_batch(self, file_indices, @@ -1085,8 +707,8 @@ class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): yield j, i def _next_record_interleaved(file_indices, cycle_length): - return _interleave([_next_record([i]) for i in file_indices], - cycle_length) + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) record_batch = [] batch_index = 0 diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..e63bc4c72049c61aa40314ffebe5c4366a818d46 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -0,0 +1,331 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing reader datasets.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class FixedLengthRecordDatasetTestBase(test.TestCase): + """Base class for setting up and testing FixedLengthRecordDataset.""" + + def setUp(self): + super(FixedLengthRecordDatasetTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self._header_bytes = 5 + self._record_bytes = 3 + self._footer_bytes = 2 + + def _record(self, f, r): + return compat.as_bytes(str(f * 2 + r) * self._record_bytes) + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) + filenames.append(fn) + with open(fn, "wb") as f: + f.write(b"H" * self._header_bytes) + for j in range(self._num_records): + f.write(self._record(i, j)) + f.write(b"F" * self._footer_bytes) + return filenames + + +class ReadBatchFeaturesTestBase(test.TestCase): + """Base class for setting up and testing `make_batched_feature_dataset`.""" + + def setUp(self): + super(ReadBatchFeaturesTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self.test_filenames = self._createFiles() + + def make_batch_feature(self, + filenames, + num_epochs, + batch_size, + reader_num_threads=1, + parser_num_threads=1, + shuffle=False, + shuffle_seed=None, + drop_final_batch=False): + self.filenames = filenames + self.num_epochs = num_epochs + self.batch_size = batch_size + + return readers.make_batched_features_dataset( + file_pattern=self.filenames, + batch_size=self.batch_size, + features={ + "file": parsing_ops.FixedLenFeature([], dtypes.int64), + "record": parsing_ops.FixedLenFeature([], dtypes.int64), + "keywords": parsing_ops.VarLenFeature(dtypes.string) + }, + reader=core_readers.TFRecordDataset, + num_epochs=self.num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads, + drop_final_batch=drop_final_batch) + + def _record(self, f, r): + example = example_pb2.Example( + features=feature_pb2.Features( + feature={ + "file": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[f])), + "record": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[r])), + "keywords": + feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=self._get_keywords(f, r))) + })) + return example.SerializeToString() + + def _get_keywords(self, f, r): + num_keywords = 1 + (f + r) % 2 + keywords = [] + for index in range(num_keywords): + keywords.append(compat.as_bytes("keyword%d" % index)) + return keywords + + def _sum_keywords(self, num_files): + sum_keywords = 0 + for i in range(num_files): + for j in range(self._num_records): + sum_keywords += 1 + (i + j) % 2 + return sum_keywords + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames + + def _run_actual_batch(self, outputs, sess): + file_op = outputs["file"] + keywords_indices_op = outputs["keywords"].indices + keywords_values_op = outputs["keywords"].values + keywords_dense_shape_op = outputs["keywords"].dense_shape + record_op = outputs["record"] + return sess.run([ + file_op, keywords_indices_op, keywords_values_op, + keywords_dense_shape_op, record_op + ]) + + def _next_actual_batch(self, sess): + return self._run_actual_batch(self.outputs, sess) + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length=1): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for record in next_records: + f = record[0] + r = record[1] + file_batch.append(f) + record_batch.append(r) + keywords = self._get_keywords(f, r) + keywords_batch_values.extend(keywords) + keywords_batch_indices.extend( + [[batch_index, i] for i in range(len(keywords))]) + batch_index += 1 + keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) + if len(file_batch) == batch_size: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [batch_size, keywords_batch_max_len], record_batch + ] + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + if file_batch: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [len(file_batch), keywords_batch_max_len], record_batch + ] + + def verify_records(self, + sess, + batch_size, + file_index=None, + num_epochs=1, + interleave_cycle_length=1): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length): + actual_batch = self._next_actual_batch(sess) + for i in range(len(expected_batch)): + self.assertAllEqual(expected_batch[i], actual_batch[i]) + + +class TextLineDatasetTestBase(test.TestCase): + """Base class for setting up and testing TextLineDataset.""" + + def _lineText(self, f, l): + return compat.as_bytes("%d: %d" % (f, l)) + + def _createFiles(self, + num_files, + num_lines, + crlf=False, + compression_type=None): + filenames = [] + for i in range(num_files): + fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) + filenames.append(fn) + contents = [] + for j in range(num_lines): + contents.append(self._lineText(i, j)) + # Always include a newline after the record unless it is + # at the end of the file, in which case we include it + if j + 1 != num_lines or i == 0: + contents.append(b"\r\n" if crlf else b"\n") + contents = b"".join(contents) + + if not compression_type: + with open(fn, "wb") as f: + f.write(contents) + elif compression_type == "GZIP": + with gzip.GzipFile(fn, "wb") as f: + f.write(contents) + elif compression_type == "ZLIB": + contents = zlib.compress(contents) + with open(fn, "wb") as f: + f.write(contents) + else: + raise ValueError("Unsupported compression_type", compression_type) + + return filenames + + +class TFRecordDatasetTestBase(test.TestCase): + """Base class for setting up and testing TFRecordDataset.""" + + def setUp(self): + super(TFRecordDatasetTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + + self.test_filenames = self._createFiles() + + self.filenames = array_ops.placeholder(dtypes.string, shape=[None]) + self.num_epochs = array_ops.placeholder_with_default( + constant_op.constant(1, dtypes.int64), shape=[]) + self.compression_type = array_ops.placeholder_with_default("", shape=[]) + self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = core_readers.TFRecordDataset( + self.filenames, self.compression_type).repeat(self.num_epochs) + batch_dataset = repeat_dataset.batch(self.batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + self.init_op = iterator.make_initializer(repeat_dataset) + self.init_batch_op = iterator.make_initializer(batch_dataset) + self.get_next = iterator.get_next() + + def _record(self, f, r): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index bdc003a8a5bd646e1d5c598befa2694da512d0a9..c5cfddb72b56a1bcffc80c0dd34994def3ee45cd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -17,10 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin import time + from absl.testing import parameterized +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index eb2ceff893543f710d4f0246adf4e6367a2deeb0..d02b3abb92f49e3e53d4217662947ab97bbd0fed 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -21,7 +21,6 @@ import itertools import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context @@ -168,18 +167,5 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) -class ScanDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, num_elements): - return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( - scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) - - def testScanCore(self): - num_output = 5 - self.run_core_tests(lambda: self._build_dataset(num_output), - lambda: self._build_dataset(2), num_output) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..686788522acdf1c5e91132c38bdc81d10d2a0cc2 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -0,0 +1,526 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "dataset_serialization_test_base", + srcs = [ + "dataset_serialization_test_base.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "batch_dataset_serialization_test", + size = "medium", + srcs = ["batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "cache_dataset_serialization_test", + size = "small", + srcs = ["cache_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "concatenate_dataset_serialization_test", + size = "small", + srcs = ["concatenate_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "dataset_constructor_serialization_test", + size = "medium", + srcs = ["dataset_constructor_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "filter_dataset_serialization_test", + size = "medium", + srcs = ["filter_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "fixed_length_record_dataset_serialization_test", + size = "medium", + srcs = ["fixed_length_record_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "flat_map_dataset_serialization_test", + size = "medium", + srcs = ["flat_map_dataset_serialization_test.py"], + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "group_by_reducer_serialization_test", + size = "medium", + srcs = ["group_by_reducer_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:grouping", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "group_by_window_serialization_test", + size = "medium", + srcs = ["group_by_window_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:grouping", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "ignore_errors_serialization_test", + size = "small", + srcs = ["ignore_errors_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:error_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "interleave_dataset_serialization_test", + size = "medium", + srcs = ["interleave_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "map_and_batch_dataset_serialization_test", + size = "medium", + srcs = ["map_and_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "map_dataset_serialization_test", + size = "medium", + srcs = ["map_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "optimize_dataset_serialization_test", + size = "small", + srcs = ["optimize_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "padded_batch_dataset_serialization_test", + size = "medium", + srcs = ["padded_batch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:string_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parallel_interleave_dataset_serialization_test", + size = "medium", + srcs = ["parallel_interleave_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "parallel_map_dataset_serialization_test", + size = "medium", + srcs = ["parallel_map_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "prefetch_dataset_serialization_test", + size = "small", + srcs = ["prefetch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "range_dataset_serialization_test", + size = "small", + srcs = ["range_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sample_from_datasets_serialization_test", + size = "medium", + srcs = ["sample_from_datasets_serialization_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "scan_dataset_serialization_test", + size = "small", + srcs = ["scan_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:scan_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sequence_dataset_serialization_test", + size = "medium", + srcs = ["sequence_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "serialization_integration_test", + size = "small", + srcs = ["serialization_integration_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "shuffle_and_repeat_dataset_serialization_test", + size = "medium", + srcs = ["shuffle_and_repeat_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:shuffle_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "shuffle_dataset_serialization_test", + size = "medium", + srcs = ["shuffle_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "sql_dataset_serialization_test", + size = "small", + srcs = ["sql_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:sql_dataset_op_test_base", + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + ], +) + +py_test( + name = "stats_dataset_serialization_test", + size = "medium", + srcs = ["stats_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "textline_dataset_serialization_test", + size = "medium", + srcs = ["textline_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "tf_record_dataset_serialization_test", + size = "medium", + srcs = ["tf_record_dataset_serialization_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:readers", + ], +) + +py_test( + name = "unbatch_dataset_serialization_test", + size = "medium", + srcs = ["unbatch_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "unique_dataset_serialization_test", + size = "small", + srcs = ["unique_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/ops:unique", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "zip_dataset_serialization_test", + size = "small", + srcs = ["zip_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/python:client_testlib", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..af87d8b6083de268fafd4346d2871f14e0f4e7c9 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py @@ -0,0 +1,83 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the BatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class BatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len // batch_size + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + def _build_dataset_dense_to_sparse(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, [12])) + + def testDenseToSparseBatchDatasetCore(self): + components = np.random.randint(5, size=(40,)).astype(np.int32) + diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) + + num_outputs = len(components) // 4 + self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), + lambda: self._build_dataset_dense_to_sparse(diff_comp), + num_outputs) + + def _sparse(self, i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + def _build_dataset_sparse(self, batch_size=5): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) + + def testSparseCore(self): + self.run_core_tests(self._build_dataset_sparse, + lambda: self._build_dataset_sparse(2), 2) + + def _build_dataset_nested_sparse(self): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) + + def testNestedSparseCore(self): + self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a1100893c7384b0e2bd9fcfdaa8d3698b95d28 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the CacheDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class CacheDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.range_size = 10 + self.num_repeats = 3 + self.num_outputs = self.range_size * self.num_repeats + self.cache_file_prefix = 'test' + + def ds_fn(self): + return dataset_ops.Dataset.range(self.range_size).cache( + os.path.join(self.get_temp_dir(), + self.cache_file_prefix)).repeat(self.num_repeats) + + def expected_outputs(self): + return list(range(self.range_size)) * self.num_repeats + + def testCheckpointBeforeOneEpoch(self): + # Generate 5 entries from iterator and save checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + def testCheckpointBeforeOneEpochThenRunFewSteps(self): + # Generate 8 entries from iterator but save checkpoint after producing + # 5. + outputs = self.gen_outputs( + self.ds_fn, [5], + 8, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, range(8)) + + # Restoring from checkpoint and running GetNext should return a + # `AlreadExistsError` now because the lockfile already exists. + with self.assertRaises(errors.AlreadyExistsError): + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + + def testCheckpointAfterOneEpoch(self): + # Generate 15 entries from iterator and save checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 15, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + def testCheckpointAfterOneEpochThenRunFewSteps(self): + # Generate 18 entries from iterator but save checkpoint after producing + # 15. + outputs = self.gen_outputs( + self.ds_fn, [15], + 18, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(8))) + + outputs = list(range(10)) + list(range(5)) + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 15, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointBeforeOneEpochButRunCompleteEpoch(self): + # Generate 13 entries from iterator but save checkpoint after producing + # 5. + outputs = self.gen_outputs( + self.ds_fn, [5], + 13, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(3))) + + # Since we ran for more than one epoch, the cache was completely written. + # The ckpt was saved when the iterator was in cache-write mode. Test that + # the iterator falls back to read mode after restoring if the cache has + # been completely written. + + outputs = list(range(5)) + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointUnusedWriterIterator(self): + # Checkpoint before get_next is called even once. + outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False) + self.assertSequenceEqual(outputs, []) + + outputs = self.gen_outputs( + self.ds_fn, [], + self.num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointUnusedMidwayWriterIterator(self): + # Produce 5 elements and checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint, then produce no elements and checkpoint. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce rest of the elements. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testUnusedCheckpointError(self): + # Produce 5 elements and save ckpt. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + self.ds_fn, [], self.num_outputs, verify_exhausted=False) + + def testIgnoreCheckpointIfCacheWritten(self): + # Produce 15 elements and save ckpt. This will write the complete cache. + outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Build the iterator again but do not restore from ckpt. Since the cache + # has already been written we should be able to use it. + outputs = self.gen_outputs( + self.ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py similarity index 92% rename from tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py index 17f2980157ddd0350dafd1d745cbb9b64e65f7c5..96f13d75a31b6762b35062e6cf8c0cdb4d61d2c5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the ConcatenateDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2139b5c33db69a7ffbdebee74e5824928004b407 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the dataset constructors serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test + + +class FromTensorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_tensor_dataset(self, variable_array): + components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) + + return dataset_ops.Dataset.from_tensors(components) + + def testFromTensorsCore(self): + # Equal length components + arr = np.array(1) + num_outputs = 1 + diff_arr = np.array(2) + self.run_core_tests(lambda: self._build_tensor_dataset(arr), + lambda: self._build_tensor_dataset(diff_arr), + num_outputs) + + +class FromTensorSlicesSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_tensor_slices_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components) + + def testFromTensorSlicesCore(self): + # Equal length components + components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), + np.array([37.0, 38.0, 39.0, 40.0])) + + diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[5], [6], [7], [8]]), 22), + np.array([1.0, 2.0, 3.0, 4.0])) + + dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} + + self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), + lambda: self._build_tensor_slices_dataset(diff_comp), 4) + self.run_core_tests( + lambda: self._build_tensor_slices_dataset(dict_components), None, 3) + + +class FromSparseTensorSlicesSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_sparse_tensor_slice_dataset(self, slices): + indices = np.array( + [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], + dtype=np.int64) + values = np.array([val for s in slices for val in s], dtype=np.float64) + dense_shape = np.array( + [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) + sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) + return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) + + def testFromSparseTensorSlicesCore(self): + slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] + diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] + + self.run_core_tests( + lambda: self._build_sparse_tensor_slice_dataset(slices), + lambda: self._build_sparse_tensor_slice_dataset(diff_slices), + 9, + sparse_tensors=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py similarity index 97% rename from tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py index 78ecce8f7daaf84002ae78d8d77820755b967d89..393f08850b1865180a8b94e9209b2445b54c8b69 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py @@ -467,7 +467,8 @@ class DatasetSerializationTestBase(test.TestCase): ckpt_saved=False, init_before_restore=False, sparse_tensors=False, - verify_exhausted=True): + verify_exhausted=True, + save_checkpoint_at_end=True): """Generates elements from input dataset while stopping at break points. Produces `num_outputs` outputs and saves the state of the iterator in the @@ -490,6 +491,10 @@ class DatasetSerializationTestBase(test.TestCase): sparse_tensors: Whether dataset is built from SparseTensor(s). verify_exhausted: Whether to verify that the iterator has been exhausted after producing `num_outputs` elements. + save_checkpoint_at_end: Whether to save a checkpoint after producing all + outputs. If False, checkpoints are saved each break point but not at the + end. Note that checkpoints overwrite each other so there is always only + a single checkpoint available. Defaults to True. Returns: A list of `num_outputs` items. @@ -526,8 +531,9 @@ class DatasetSerializationTestBase(test.TestCase): if i == len(break_points) and verify_exhausted: with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) - self._save(sess, saver) - ckpt_saved = True + if save_checkpoint_at_end or i < len(break_points): + self._save(sess, saver) + ckpt_saved = True return outputs diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py similarity index 91% rename from tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py index b572d6ed770fc0fe0f852359baf343c55966eddd..7c170078a11aadce9e5730437e4c25209bd58edb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the FilterDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops @@ -35,7 +35,7 @@ class FilterDatasetSerializationTest( def testFilterCore(self): div = 3 - num_outputs = np.sum([x % 3 is not 2 for x in range(100)]) + num_outputs = np.sum([x % 3 != 2 for x in range(100)]) self.run_core_tests(lambda: self._build_filter_range_graph(div), lambda: self._build_filter_range_graph(div * 2), num_outputs) diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34392d88d4505175c4562e23d5f0c4116e00b022 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py @@ -0,0 +1,45 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the FixedLengthRecordDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class FixedLengthRecordDatasetSerializationTest( + reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, num_epochs, compression_type=None): + filenames = self._createFiles() + return core_readers.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, + self._footer_bytes).repeat(num_epochs) + + def testFixedLengthRecordCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py similarity index 96% rename from tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py index f3feecef32e587045be25056815315136a883ca7..16051ffd3fd1e1e7ff419f28109df7bc1f165257 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the FlatMapDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..571e0899bbc1f856d66f85c4f6f3ac78aa0b1368 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py @@ -0,0 +1,61 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the GroupByReducer serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class GroupByReducerSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + return dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_reducer(lambda x: x % 5, reducer)) + + def testCoreGroupByReducer(self): + components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 5, + verify_exhausted=True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f86af4084ef61c2f20dbe2fb388a20287676f39d --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the GroupByWindow serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class GroupByWindowSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( + grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) + + def testCoreGroupByWindow(self): + components = np.array( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 12, + verify_exhausted=False) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..65ae9923b8f64dddcd54afc53e2fa67bc770fc2a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the IgnoreErrors input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class IgnoreErrorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors()) + + def testIgnoreErrorsCore(self): + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) + diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) + num_outputs = 4 + self.run_core_tests(lambda: self._build_ds(components), + lambda: self._build_ds(diff_components), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3892fe81a1c0d325ddc5f501c2caed4b53f5d5 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py @@ -0,0 +1,86 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the InterleaveDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class InterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, input_values, cycle_length, block_length): + repeat_count = 2 + return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + repeat_count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length) + + def testSerializationCore(self): + input_values = np.array([4, 5, 6], dtype=np.int64) + num_outputs = np.sum(input_values) * 2 + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + lambda: self._build_iterator_graph( + input_values, cycle_length * 2, block_length * 1), + num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # pylint: enable=g-long-lambda + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1) + + self.run_core_tests(_build_dataset, None, 20) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cd211328fa595c0ce0efe3509e8ba9dc06af80 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py @@ -0,0 +1,88 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapAndBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapAndBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testNumParallelBatches(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_batches = 2 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ab783e5cce95ed63fe64c273abb3846121c7a274 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py @@ -0,0 +1,140 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class MapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 14 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) + + def testSaveRestoreCore(self): + self.run_core_tests( + self._build_ds, + lambda: self._build_ds(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(_map_fn) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1))) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var)) + + self.run_core_tests(_build_ds, None, 10) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testSparseCore(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def _build_ds(num_outputs): + return dataset_ops.Dataset.range(num_outputs).map(_sparse) + + num_outputs = 10 + self.run_core_tests(lambda: _build_ds(num_outputs), + lambda: _build_ds(int(num_outputs / 2)), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c03495e34e73018bf9832bf77cdcf038449488 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the OptimizeDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac42a461afcb6803a0e033892e74fb84d1e5e58 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the PaddedBatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class PaddedBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testPaddedBatch(self): + + def build_dataset(seq_lens): + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=[-1]) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + def testPaddedBatchNonDefaultPadding(self): + + def build_dataset(seq_lens): + + def fill_tuple(x): + filled = array_ops.fill([x], x) + return (filled, string_ops.as_string(filled)) + + padded_shape = [-1] + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + fill_tuple).padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "")) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1f8a584df902180aa7ab020b47ecc749912a3a3a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParallelInterleaveDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class ParallelInterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.input_values = np.array([4, 5, 6], dtype=np.int64) + self.num_repeats = 2 + self.num_outputs = np.sum(self.input_values) * 2 + + def _build_ds(self, cycle_length, block_length, sloppy=False): + return (dataset_ops.Dataset.from_tensor_slices( + self.input_values).repeat(self.num_repeats).apply( + interleave_ops.parallel_interleave( + lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), + cycle_length, block_length, sloppy))) + + def testSerializationCore(self): + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + self.run_core_tests( + lambda: self._build_ds(cycle_length, block_length), + lambda: self._build_ds(cycle_length * 2, block_length * 1), + self.num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + + def testSerializationWithSloppy(self): + break_points = self.gen_break_points(self.num_outputs, 10) + expected_outputs = np.repeat( + np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), + self.num_repeats).tolist() + + def run_test(cycle_length, block_length): + actual = self.gen_outputs( + lambda: self._build_ds(cycle_length, block_length, True), + break_points, self.num_outputs) + self.assertSequenceEqual(sorted(actual), expected_outputs) + + # cycle_length > 1, block_length > 1 + run_test(2, 3) + # cycle_length = 1 + run_test(1, 3) + # block_length = 1 + run_test(2, 1) + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).apply( + interleave_ops.parallel_interleave(_interleave_fn, 1)) + + self.run_core_tests(_build_dataset, None, 20) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb7605be1f230cef4cdae30aa672842a678edf7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParallelMapDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class ParallelMapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 1 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) + + def _build_ds_with_prefetch(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) + + def testSaveRestoreCore(self): + for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: + self.run_core_tests( + ds_fn, + lambda: ds_fn(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map( + _map_fn, num_parallel_calls=2).prefetch(2) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1), + num_parallel_calls=2).prefetch(2)) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) + + self.run_core_tests(_build_ds, None, 10) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) + + self.run_core_tests(_build_ds, None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py similarity index 90% rename from tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py index 3d120a3071ef730f21221e3291d8c84385b51aa3..c802402461216de33e7d3232ba38063c27f33557 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the PrefetchDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f5b6cf5db788ad2fd09b7e93d0ae5ebb530a11 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the RangeDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class RangeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _iterator_checkpoint_prefix_local(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix_local(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix_local()), + dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def testSaveRestore(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Saving and restoring in same session. + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def _build_range_dataset(self, start, stop): + return dataset_ops.Dataset.range(start, stop) + + def testRangeCore(self): + start = 2 + stop = 10 + stop_1 = 8 + self.run_core_tests(lambda: self._build_range_dataset(start, stop), + lambda: self._build_range_dataset(start, stop_1), + stop - start) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb35ea624c22ad0a9561d774c86247119c4c837 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py @@ -0,0 +1,46 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SampleFromDatasets serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..af9ef48c0f3b92f61c097410ef4dfd787292e76a --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ScanDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import scan_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ScanDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_elements): + return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) + + def testScanCore(self): + num_output = 5 + self.run_core_tests(lambda: self._build_dataset(num_output), + lambda: self._build_dataset(2), num_output) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py similarity index 91% rename from tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py index d0cb203a3afd2775756c8542a1e86faedc5cee53..2afebca0f5849c640044830fff05ebff131e0875 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the sequence datasets serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test -class SequenceDatasetSerializationTest( +class SkipDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_skip_dataset(self, count): @@ -52,6 +52,10 @@ class SequenceDatasetSerializationTest( 'Shape must be rank 0 but is rank 1'): self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0) + +class TakeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + def _build_take_dataset(self, count): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take(count) @@ -79,6 +83,10 @@ class SequenceDatasetSerializationTest( 'Shape must be rank 0 but is rank 1'): self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0) + +class RepeatDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + def _build_repeat_dataset(self, count, take_count=3): components = (np.arange(10),) return dataset_ops.Dataset.from_tensor_slices(components).take( @@ -117,5 +125,5 @@ class SequenceDatasetSerializationTest( None, 0) -if __name__ == "__main__": +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py similarity index 96% rename from tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py index 0a6b74dc3eb80a6168117beed06935737198cecb..992d996a485de94ad55305552e42c7fbc92ec64b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Integration test for input pipeline serialization.""" +"""Integration test for dataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -26,7 +26,7 @@ from tensorflow.python.platform import test from tensorflow.python.training import saver as saver_lib -class MultipleInputPipelinesTest(test.TestCase): +class SerializationIntegrationTest(test.TestCase): def _build_input_pipeline(self, name, num_outputs): with ops.name_scope(name): diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f199ec835ef1c72e2c3f8b3b1cc4f5fe6ea0b6f4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ShuffleAndRepeatDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d46c762aaaadc4314a10acc5aeb7ace7df5002a8 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py @@ -0,0 +1,148 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ShuffleDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class ShuffleDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_shuffle_dataset( + self, + range_limit=10, + num_repeats=5, + buffer_size=5, + seed=None, + reshuffle_each_iteration=None, + ): + return dataset_ops.Dataset.range(range_limit).shuffle( + buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) + + def testShuffleCore(self): + + seed = 55 + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + # pylint: disable=cell-var-from-loop + # pylint: disable=g-long-lambda + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + def testNonDeterministicSeeding(self): + + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + # We checkpoint the initial state of the Dataset so that we can restore + # the seeds in the next run. Since the seeding is non-deterministic + # the dataset gets initialized with different seeds each time. + expected = self.gen_outputs( + ds_fn, + break_points=[0], + num_outputs=num_outputs, + ckpt_saved=False, + verify_exhausted=False, + save_checkpoint_at_end=False) + actual = self.gen_outputs( + ds_fn, + break_points=self.gen_break_points(num_outputs), + num_outputs=num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.match(expected, actual) + + def testMultipleIterators(self): + range_limit = 5 + num_repeats = 2 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 5, 8, 10] + + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + with ops.Graph().as_default() as g: + ds = ds_fn() + iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] + get_next_ops = [it.get_next() for it in iterators] + saveables = [ + contrib_iterator_ops.make_saveable_from_iterator(it) + for it in iterators + ] + for saveable in saveables: + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver(allow_empty=True) + with self.test_session(graph=g) as sess: + self._save(sess, saver) + expected = [sess.run(get_next_ops) for _ in range(num_outputs)] + self._restore(saver, sess) + actual = [sess.run(get_next_ops) for _ in range(num_outputs)] + self.match(expected, actual) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..93b26ed58a065de2074906528a0f49d696a813ff --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the SqlDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetSerializationTest( + sql_dataset_op_test_base.SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..14cd3e9c4a72cc7832f9bb1cb49c72a8a7cb2dcd --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the StatsDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import stats_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the +# transformation `stats_ops.set_stats_aggregator`, since we don't support +# serializing StatsAggregator yet. +class StatsDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset_bytes_stats(self, num_elements): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + + def test_bytes_produced_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.bytes_produced_stats(["bytes_produced"])), + None, 100) + # pylint: enable=g-long-lambda + + def testBytesStatsDatasetSaveableCore(self): + num_outputs = 100 + self.run_core_tests( + lambda: self._build_dataset_bytes_stats(num_outputs), + lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) + + def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag)) + + def _build_dataset_multiple_tags(self, + num_elements, + tag1="record_latency", + tag2="record_latency_2"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + + def test_latency_stats_invalid_tag_shape(self): + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1"): + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats(["record_latency", "record_latency_2"])), + None, 100) + # pylint: enable=g-long-lambda + + def testLatencyStatsDatasetSaveableCore(self): + num_outputs = 100 + + self.run_core_tests( + lambda: self._build_dataset_latency_stats(num_outputs), + lambda: self._build_dataset_latency_stats(num_outputs // 10), + num_outputs) + + self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), + None, num_outputs) + + tag1 = "record_latency" + tag2 = "record_latency" + self.run_core_tests( + lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), + None, num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2483787f44f913199e3f2aa46d181d609a4a9a8f --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py @@ -0,0 +1,53 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the TextLineDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class TextLineDatasetSerializationTest( + reader_dataset_ops_test_base.TextLineDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, test_filenames, compression_type=None): + return core_readers.TextLineDataset( + test_filenames, compression_type=compression_type, buffer_size=10) + + def testTextLineCore(self): + compression_types = [None, "GZIP", "ZLIB"] + num_files = 5 + lines_per_file = 5 + num_outputs = num_files * lines_per_file + for compression_type in compression_types: + test_filenames = self._createFiles( + num_files, + lines_per_file, + crlf=True, + compression_type=compression_type) + # pylint: disable=cell-var-from-loop + self.run_core_tests( + lambda: self._build_iterator_graph(test_filenames, compression_type), + lambda: self._build_iterator_graph(test_filenames), num_outputs) + # pylint: enable=cell-var-from-loop + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..55a6257a274cd7f78e3818943627cfa09a185fd7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the TFRecordDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import zlib + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.platform import test + + +class TFRecordDatasetSerializationTest( + reader_dataset_ops_test_base.TFRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, + num_epochs, + batch_size=1, + compression_type=None, + buffer_size=None): + filenames = self._createFiles() + if compression_type == "ZLIB": + zlib_files = [] + for i, fn in enumerate(filenames): + with open(fn, "rb") as f: + cdata = zlib.compress(f.read()) + zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) + with open(zfn, "wb") as f: + f.write(cdata) + zlib_files.append(zfn) + filenames = zlib_files + + elif compression_type == "GZIP": + gzip_files = [] + for i, fn in enumerate(self.test_filenames): + with open(fn, "rb") as f: + gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) + with gzip.GzipFile(gzfn, "wb") as gzf: + gzf.write(f.read()) + gzip_files.append(gzfn) + filenames = gzip_files + + return core_readers.TFRecordDataset( + filenames, compression_type, + buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) + + def testTFRecordWithoutBufferCore(self): + num_epochs = 5 + batch_size = num_epochs + num_outputs = num_epochs * self._num_files * self._num_records // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, batch_size, + buffer_size=0), + lambda: self._build_iterator_graph(num_epochs * 2, batch_size), + num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, + num_outputs * batch_size) + # pylint: enable=g-long-lambda + + def testTFRecordWithBufferCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + def testTFRecordWithCompressionCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b2a5a8a20dd7a9f891b07351570006636ca34bd0 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the UnbatchDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class UnbatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch( + batch_size).apply(batching.unbatch()) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..22f15b88464a770207dc7c6f0387d73ea3d5c2e4 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the UniqueDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test + + +class UniqueDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testUnique(self): + + def build_dataset(num_elements, unique_elem_range): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: x % unique_elem_range).apply(unique.unique()) + + self.run_core_tests(lambda: build_dataset(200, 100), + lambda: build_dataset(40, 100), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py similarity index 92% rename from tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py index e39fa957f0bbb9d3671274d5f58b993e8399814b..340a6ff72e6813c3743d3d83a72ac12d4a392b66 100644 --- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the experimental input pipeline ops.""" +"""Tests for the ZipDataset serialization.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index bcc644c0971854d948025009dc7add2fea214048..3c11d7a97fc9a4b2b8b19a8e82ad5e9037d6bbcd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -19,7 +19,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors @@ -27,60 +26,25 @@ from tensorflow.python.framework import ops from tensorflow.python.platform import test -class ShuffleDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_shuffle_dataset( - self, - range_limit=10, - num_repeats=5, - buffer_size=5, - seed=None, - reshuffle_each_iteration=None, - ): - return dataset_ops.Dataset.range(range_limit).shuffle( - buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) - - def testShuffleCore(self): - - seed = 55 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 8, 10, 25, 50] - reshuffle_each_iteration = False - # pylint: disable=cell-var-from-loop - # pylint: disable=g-long-lambda - for buffer_size in buffer_sizes: - self.run_core_tests( - lambda: self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration), - lambda: self._build_shuffle_dataset( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=10, - reshuffle_each_iteration=reshuffle_each_iteration), - num_outputs) - # pylint: enable=cell-var-from-loop - # pylint: enable=g-long-lambda - - -class ShuffleAndRepeatTest( - dataset_serialization_test_base.DatasetSerializationTestBase): +class ShuffleAndRepeatTest(test.TestCase): def _build_ds(self, seed, count=5, num_elements=20): return dataset_ops.Dataset.range(num_elements).apply( shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True): + get_next = ds_fn().make_one_shot_iterator().get_next() + outputs = [] + with self.test_session() as sess: + for _ in range(num_outputs): + outputs.append(sess.run(get_next)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + return outputs + def testCorrectOutput(self): - output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output = self._gen_outputs(lambda: self._build_ds(10), 100) self.assertSequenceEqual( sorted(output), sorted( np.array([range(20) for _ in range(5)]).flatten())) @@ -89,53 +53,53 @@ class ShuffleAndRepeatTest( def testReshuffling(self): # Check that the output orders of different epochs are indeed different. - output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output = self._gen_outputs(lambda: self._build_ds(10), 100) for i in range(4): epoch1 = output[i * 20:(i + 1) * 20] epoch2 = output[(i + 1) * 20:(i + 2) * 20] self.assertNotEqual(epoch1, epoch2) def testSameOrderForSameSeeds(self): - output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) - output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output1 = self._gen_outputs(lambda: self._build_ds(10), 100) + output2 = self._gen_outputs(lambda: self._build_ds(10), 100) self.assertEqual(output1, output2) def testDifferentOrderForDifferentSeeds(self): - output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) - output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) + output1 = self._gen_outputs(lambda: self._build_ds(10), 100) + output2 = self._gen_outputs(lambda: self._build_ds(20), 100) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testCountNone(self): - output1 = self.gen_outputs( - lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) - output2 = self.gen_outputs( - lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) + output1 = self._gen_outputs( + lambda: self._build_ds(10, count=None), 100, verify_exhausted=False) + output2 = self._gen_outputs( + lambda: self._build_ds(20, count=None), 100, verify_exhausted=False) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testCountMinusOne(self): - output1 = self.gen_outputs( - lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) - output2 = self.gen_outputs( - lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) + output1 = self._gen_outputs( + lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False) + output2 = self._gen_outputs( + lambda: self._build_ds(20, count=-1), 100, verify_exhausted=False) self.assertNotEqual(output1, output2) self.assertEqual(sorted(output1), sorted(output2)) def testInfiniteOutputs(self): # Asserting the iterator is exhausted after producing 100 items should fail. with self.assertRaises(AssertionError): - self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=None), 100) with self.assertRaises(AssertionError): - self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=-1), 100) def testInfiniteEmpty(self): with self.assertRaises(errors.OutOfRangeError): - self.gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), - [], 100) + self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), + 100) with self.assertRaises(errors.OutOfRangeError): - self.gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), [], - 100) + self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), + 100) def testLargeBufferSize(self): with ops.Graph().as_default() as g: @@ -146,17 +110,5 @@ class ShuffleAndRepeatTest( sess.run(get_next_op) -class ShuffleAndRepeatSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_ds(self, seed): - return dataset_ops.Dataset.range(20).apply( - shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) - - def testCore(self): - self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), - 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index 4148addf2878c99f47ebe1454edf69ad7f38dfbc..2c2cfbebff5d3eba00f120467102b4185d81ab24 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -18,83 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - -import sqlite3 - -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import readers +from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTestBase(test.TestCase): - - def _createSqlDataset(self, output_types, num_repeats=1): - dataset = readers.SqlDataset(self.driver_name, self.data_source_name, - self.query, output_types).repeat(num_repeats) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - return init_op, get_next - - def setUp(self): - self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") - self.driver_name = array_ops.placeholder_with_default( - array_ops.constant("sqlite", dtypes.string), shape=[]) - self.query = array_ops.placeholder(dtypes.string, shape=[]) - - conn = sqlite3.connect(self.data_source_name) - c = conn.cursor() - c.execute("DROP TABLE IF EXISTS students") - c.execute("DROP TABLE IF EXISTS people") - c.execute("DROP TABLE IF EXISTS townspeople") - c.execute( - "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, " - "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), " - "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), " - "desk_number INTEGER, income INTEGER, favorite_number INTEGER, " - "favorite_big_number INTEGER, favorite_negative_number INTEGER, " - "favorite_medium_sized_number INTEGER, brownie_points INTEGER, " - "account_balance INTEGER, registration_complete INTEGER)") - c.executemany( - "INSERT INTO students (first_name, last_name, motto, school_id, " - "favorite_nonsense_word, desk_number, income, favorite_number, " - "favorite_big_number, favorite_negative_number, " - "favorite_medium_sized_number, brownie_points, account_balance, " - "registration_complete) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647, - 9223372036854775807, -2, 32767, 0, 0, 1), - ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000, - -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)]) - c.execute( - "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " - "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") - c.executemany( - "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)", - [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe", - "California")]) - c.execute( - "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY " - "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories " - "FLOAT, accolades FLOAT, triumphs FLOAT)") - c.executemany( - "INSERT INTO townspeople (first_name, last_name, victories, " - "accolades, triumphs) VALUES (?, ?, ?, ?, ?)", - [("George", "Washington", 20.00, - 1331241.321342132321324589798264627463827647382647382643874, - 9007199254740991.0), - ("John", "Adams", -19.95, - 1331241321342132321324589798264627463827647382647382643874.0, - 9007199254740992.0)]) - conn.commit() - conn.close() - - -class SqlDatasetTest(SqlDatasetTestBase): +class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase): # Test that SqlDataset can read from a database table. def testReadResultSet(self): @@ -656,27 +586,5 @@ class SqlDatasetTest(SqlDatasetTestBase): sess.run(get_next) -class SqlDatasetSerializationTest( - SqlDatasetTestBase, - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, num_repeats): - data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") - driver_name = array_ops.placeholder_with_default( - array_ops.constant("sqlite", dtypes.string), shape=[]) - query = ("SELECT first_name, last_name, motto FROM students ORDER BY " - "first_name DESC") - output_types = (dtypes.string, dtypes.string, dtypes.string) - return readers.SqlDataset(driver_name, data_source_name, query, - output_types).repeat(num_repeats) - - def testSQLSaveable(self): - num_repeats = 4 - num_outputs = num_repeats * 2 - self.run_core_tests(lambda: self._build_dataset(num_repeats), - lambda: self._build_dataset(num_repeats // 2), - num_outputs) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5c725a9269e80311f3e73c51c28ab80e7c4815 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing SqlDataset.""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import sqlite3 + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetTestBase(test.TestCase): + """Base class for setting up and testing SqlDataset.""" + + def _createSqlDataset(self, output_types, num_repeats=1): + dataset = readers.SqlDataset(self.driver_name, self.data_source_name, + self.query, output_types).repeat(num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + return init_op, get_next + + def setUp(self): + self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + self.driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + self.query = array_ops.placeholder(dtypes.string, shape=[]) + + conn = sqlite3.connect(self.data_source_name) + c = conn.cursor() + c.execute("DROP TABLE IF EXISTS students") + c.execute("DROP TABLE IF EXISTS people") + c.execute("DROP TABLE IF EXISTS townspeople") + c.execute( + "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), " + "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), " + "desk_number INTEGER, income INTEGER, favorite_number INTEGER, " + "favorite_big_number INTEGER, favorite_negative_number INTEGER, " + "favorite_medium_sized_number INTEGER, brownie_points INTEGER, " + "account_balance INTEGER, registration_complete INTEGER)") + c.executemany( + "INSERT INTO students (first_name, last_name, motto, school_id, " + "favorite_nonsense_word, desk_number, income, favorite_number, " + "favorite_big_number, favorite_negative_number, " + "favorite_medium_sized_number, brownie_points, account_balance, " + "registration_complete) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647, + 9223372036854775807, -2, 32767, 0, 0, 1), + ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000, + -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)]) + c.execute( + "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") + c.executemany( + "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)", + [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe", + "California")]) + c.execute( + "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY " + "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories " + "FLOAT, accolades FLOAT, triumphs FLOAT)") + c.executemany( + "INSERT INTO townspeople (first_name, last_name, victories, " + "accolades, triumphs) VALUES (?, ?, ?, ?, ?)", + [("George", "Washington", 20.00, + 1331241.321342132321324589798264627463827647382647382643874, + 9007199254740991.0), + ("John", "Adams", -19.95, + 1331241321342132321324589798264627463827647382647382643874.0, + 9007199254740992.0)]) + conn.commit() + conn.close() + + diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 5c74ed6ae7210e8e22efb6e8fdb773397459ce1e..b4945685c1d1062bf416b73f1541f351adf45604 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -19,7 +19,7 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops @@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTest(test.TestCase): +class StatsDatasetTestBase(test.TestCase): def _assertSummaryHasCount(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() @@ -49,6 +49,9 @@ class StatsDatasetTest(test.TestCase): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + +class StatsDatasetTest(StatsDatasetTestBase): + def testBytesProduced(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).map( @@ -193,68 +196,44 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) -class StatsDatasetSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset_bytes_stats(self, num_elements): - return dataset_ops.Dataset.range(num_elements).map( - lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( - stats_ops.bytes_produced_stats("bytes_produced")) - - def test_bytes_produced_stats_invalid_tag_shape(self): - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): - self.run_core_tests( - lambda: dataset_ops.Dataset.range(100).apply( - stats_ops.bytes_produced_stats(["bytes_produced"])), - None, 100) - - def testBytesStatsDatasetSaveableCore(self): - num_outputs = 100 - self.run_core_tests( - lambda: self._build_dataset_bytes_stats(num_outputs), - lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) +class FeatureStatsDatasetTest( + StatsDatasetTestBase, + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): - def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): - return dataset_ops.Dataset.range(num_elements).apply( - stats_ops.latency_stats(tag)) - - def _build_dataset_multiple_tags(self, - num_elements, - tag1="record_latency", - tag2="record_latency_2"): - return dataset_ops.Dataset.range(num_elements).apply( - stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) - - def test_latency_stats_invalid_tag_shape(self): - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): - self.run_core_tests( - lambda: dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats(["record_latency", "record_latency_2"])), - None, 100) - - def testLatencyStatsDatasetSaveableCore(self): - num_outputs = 100 - - self.run_core_tests( - lambda: self._build_dataset_latency_stats(num_outputs), - lambda: self._build_dataset_latency_stats(num_outputs // 10), - num_outputs) - - self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), - None, num_outputs) + def testFeaturesStats(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + batch_size = 2 + stats_aggregator = stats_ops.StatsAggregator() + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5, + drop_final_batch=True).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() - tag1 = "record_latency" - tag2 = "record_latency" - self.run_core_tests( - lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), - None, num_outputs) + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(total_records // batch_size): + sess.run(next_element) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:features", total_records) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:feature-values", total_records) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:features", total_records * 3) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:feature-values", + self._sum_keywords(1) * num_epochs + 2 * total_records) -# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the -# transformation `stats_ops.set_stats_aggregator`, since we don't support -# serializing StatsAggregator yet. if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 9167cb3379bba5cb1ba76a96549395c45dca9e35..0486e2bce20e9dcf81dcb5ac49fe5b397e44bf0c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import threading +from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import threadpool @@ -30,9 +31,11 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class OverrideThreadpoolDatasetTest(test.TestCase): +class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): - def testNumThreads(self): + @parameterized.parameters((1, None), (2, None), (4, None), (8, None), + (16, None), (4, -1), (4, 0), (4, 1), (4, 4)) + def testNumThreads(self, num_threads, max_intra_op_parallelism): def get_thread_id(_): # Python creates a dummy thread object to represent the current @@ -42,35 +45,35 @@ class OverrideThreadpoolDatasetTest(test.TestCase): # identifier that maps one-to-one with the underlying OS thread. return np.array(threading.current_thread().ident).astype(np.int64) - for num_threads in [1, 2, 4, 8, 16]: + dataset = ( + dataset_ops.Dataset.range(1000).map( + lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), + num_parallel_calls=32).apply(unique.unique())) - dataset = ( - dataset_ops.Dataset.range(1000).map( - lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), - num_parallel_calls=32).apply(unique.unique())) + dataset = threadpool.override_threadpool( + dataset, + threadpool.PrivateThreadPool( + num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, + display_name="private_thread_pool_%d" % num_threads)) - dataset = threadpool.override_threadpool( - dataset, - threadpool.PrivateThreadPool( - num_threads, display_name="private_thread_pool_%d" % num_threads)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() - iterator = dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - sess.run(iterator.initializer) - thread_ids = [] - try: - while True: - thread_ids.append(sess.run(next_element)) - except errors.OutOfRangeError: - pass - self.assertEqual(len(thread_ids), len(set(thread_ids))) - self.assertGreater(len(thread_ids), 0) - # NOTE(mrry): We don't control the thread pool scheduling, and - # so cannot guarantee that all of the threads in the pool will - # perform work. - self.assertLessEqual(len(thread_ids), num_threads) + with self.test_session() as sess: + sess.run(iterator.initializer) + thread_ids = [] + try: + while True: + thread_ids.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + self.assertEqual(len(thread_ids), len(set(thread_ids))) + self.assertGreater(len(thread_ids), 0) + # NOTE(mrry): We don't control the thread pool scheduling, and + # so cannot guarantee that all of the threads in the pool will + # perform work. + self.assertLessEqual(len(thread_ids), num_threads) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index 3c436f7a0b45a13109960e87dd97ca56b10bb871..d79a842e7a5d816e2e6a52fc83acbd6b260cf64b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import unique from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes @@ -79,18 +78,5 @@ class UniqueDatasetTest(test.TestCase): ]) -class UniqueSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def testUnique(self): - - def build_dataset(num_elements, unique_elem_range): - return dataset_ops.Dataset.range(num_elements).map( - lambda x: x % unique_elem_range).apply(unique.unique()) - - self.run_core_tests(lambda: build_dataset(200, 100), - lambda: build_dataset(40, 100), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 086661adb7603345be09a4c710d4fb2b170ac8f9..02408145625b7e751541e7b87dc4fd5da4f7cad9 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -49,26 +49,6 @@ py_library( ], ) -py_test( - name = "iterator_ops_test", - size = "small", - srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - ":iterator_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", - ], -) - py_library( name = "random_ops", srcs = [ @@ -96,8 +76,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":batching", + ":gen_dataset_ops", ":interleave_ops", ":shuffle_ops", + ":stats_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", @@ -106,12 +88,12 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -142,6 +124,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", ], diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index b9393de4e90ae2597045b29070934b94e18cfcbd..5708d47c2081976f82722018adf30523c091416a 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib.framework import with_shape from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -29,6 +30,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import deprecation def dense_to_sparse_batch(batch_size, row_shape): @@ -75,17 +77,17 @@ def dense_to_sparse_batch(batch_size, row_shape): """ def _apply_fn(dataset): - return DenseToSparseBatchDataset(dataset, batch_size, row_shape) + return _DenseToSparseBatchDataset(dataset, batch_size, row_shape) return _apply_fn -class UnbatchDataset(dataset_ops.Dataset): +class _UnbatchDataset(dataset_ops.Dataset): """A dataset that splits the elements of its input into multiple elements.""" def __init__(self, input_dataset): """See `unbatch()` for more details.""" - super(UnbatchDataset, self).__init__() + super(_UnbatchDataset, self).__init__() flat_shapes = nest.flatten(input_dataset.output_shapes) if any(s.ndims == 0 for s in flat_shapes): raise ValueError("Cannot unbatch an input with scalar components.") @@ -101,10 +103,7 @@ class UnbatchDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.unbatch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): @@ -145,7 +144,7 @@ def unbatch(): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" if not sparse.any_sparse(dataset.output_classes): - return UnbatchDataset(dataset) + return _UnbatchDataset(dataset) # NOTE(mrry): We must ensure that any SparseTensors in `dataset` # are normalized to the rank-1 dense representation, so that the @@ -171,7 +170,7 @@ def unbatch(): dataset.output_shapes, dataset.output_classes, allow_unsafe_cast=True) - return UnbatchDataset(restructured_dataset) + return _UnbatchDataset(restructured_dataset) return _apply_fn @@ -218,6 +217,8 @@ def filter_irregular_batches(batch_size): return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.") def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). @@ -250,12 +251,16 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time + # after 6/30/2018. batched = dataset.batch(batch_size) return filter_irregular_batches(batch_size)(batched) return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.padded_batch(..., drop_remainder=True)`.") def padded_batch_and_drop_remainder(batch_size, padded_shapes, padding_values=None): @@ -284,6 +289,8 @@ def padded_batch_and_drop_remainder(batch_size, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)` + # any time after 6/30/2018. batched = dataset.padded_batch( batch_size, padded_shapes=padded_shapes, padding_values=padding_values) return filter_irregular_batches(batch_size)(batched) @@ -291,12 +298,12 @@ def padded_batch_and_drop_remainder(batch_size, return _apply_fn -class DenseToSparseBatchDataset(dataset_ops.Dataset): +class _DenseToSparseBatchDataset(dataset_ops.Dataset): """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" def __init__(self, input_dataset, batch_size, row_shape): """See `Dataset.dense_to_sparse_batch()` for more details.""" - super(DenseToSparseBatchDataset, self).__init__() + super(_DenseToSparseBatchDataset, self).__init__() if not isinstance(input_dataset.output_types, dtypes.DType): raise TypeError("DenseToSparseDataset requires an input whose elements " "have a single component, whereas the input has %r." % @@ -309,11 +316,8 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): return gen_dataset_ops.dense_to_sparse_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._batch_size, - row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + row_shape=convert.partial_shape_to_tensor(self._row_shape), + **dataset_ops.flat_structure(self)) @property def output_classes(self): @@ -490,10 +494,7 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): batch_size=self._batch_size_t, num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 6c21e489f7c35484ebacd465e3b46d6920df5933..d46d96c461ad4cc0ac25a8ddc285cec23d09c682 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse def ignore_errors(): @@ -48,26 +46,23 @@ def ignore_errors(): """ def _apply_fn(dataset): - return IgnoreErrorsDataset(dataset) + return _IgnoreErrorsDataset(dataset) return _apply_fn -class IgnoreErrorsDataset(dataset_ops.Dataset): +class _IgnoreErrorsDataset(dataset_ops.Dataset): """A `Dataset` that silently ignores errors when computing its input.""" def __init__(self, input_dataset): """See `Dataset.ignore_errors()` for details.""" - super(IgnoreErrorsDataset, self).__init__() + super(_IgnoreErrorsDataset, self).__init__() self._input_dataset = input_dataset def _as_variant_tensor(self): return gen_dataset_ops.ignore_errors_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index 3a07df572748e464284f580d67e3a664e71acdfe..0f4cd8e20c5727a5bcfa1dce4dadbfa8f90bd551 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -64,10 +64,7 @@ def get_single_element(dataset): nested_ret = nest.pack_sequence_as( dataset.output_types, gen_dataset_ops.dataset_to_single_element( dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(sparse.as_dense_types( - dataset.output_types, dataset.output_classes)), - output_shapes=nest.flatten(sparse.as_dense_shapes( - dataset.output_shapes, dataset.output_classes)))) + **dataset_ops.flat_structure(dataset))) return sparse.deserialize_sparse_tensors( nested_ret, dataset.output_types, dataset.output_shapes, dataset.output_classes) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index ea229b5b27b117984e508fa4edc6f1cf713008b4..348884e9fab98e5c8a04ca436253d94de7931c8b 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -21,12 +21,9 @@ import numpy as np from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -58,7 +55,7 @@ def group_by_reducer(key_func, reducer): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return GroupByReducerDataset(dataset, key_func, reducer) + return _GroupByReducerDataset(dataset, key_func, reducer) return _apply_fn @@ -116,8 +113,8 @@ def group_by_window(key_func, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return GroupByWindowDataset(dataset, key_func, reduce_func, - window_size_func) + return _GroupByWindowDataset(dataset, key_func, reduce_func, + window_size_func) return _apply_fn @@ -257,12 +254,12 @@ class _VariantDataset(dataset_ops.Dataset): return self._output_types -class GroupByReducerDataset(dataset_ops.Dataset): +class _GroupByReducerDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a reduction.""" def __init__(self, input_dataset, key_func, reducer): """See `group_by_reducer()` for details.""" - super(GroupByReducerDataset, self).__init__() + super(_GroupByReducerDataset, self).__init__() self._input_dataset = input_dataset @@ -273,67 +270,27 @@ class GroupByReducerDataset(dataset_ops.Dataset): def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_key_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - # pylint: disable=protected-access - if dataset_ops._should_unpack_args(nested_args): - ret = key_func(*nested_args) - # pylint: enable=protected-access - else: - ret = key_func(nested_args) - ret = ops.convert_to_tensor(ret) - if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar(): - raise ValueError( - "`key_func` must return a single tf.int64 tensor. " - "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape())) - return ret - - self._key_func = tf_key_func - self._key_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + key_func, "tf.contrib.data.group_by_reducer()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`key_func` must return a single tf.int64 tensor. " + "Got type=%s and shape=%s" + % (wrapped_func.output_types, wrapped_func.output_shapes)) + self._key_func = wrapped_func.function def _make_init_func(self, init_func): """Make wrapping Defun for init_func.""" - - @function.Defun(dtypes.int64) - def tf_init_func(key): - """A wrapper for Defun that facilitates shape inference.""" - key.set_shape([]) - ret = init_func(key) - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._state_classes = sparse.get_classes(ret) - self._state_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._state_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._init_func = tf_init_func - self._init_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + init_func, "tf.contrib.data.group_by_reducer()", + input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), + input_types=dtypes.int64) + self._init_func = wrapped_func.function + self._state_classes = wrapped_func.output_classes + self._state_shapes = wrapped_func.output_shapes + self._state_types = wrapped_func.output_types def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping Defun for reduce_func.""" @@ -343,83 +300,47 @@ class GroupByReducerDataset(dataset_ops.Dataset): need_to_rerun = True while need_to_rerun: - # Create a list in which `tf_reduce_func` will store the new shapes. - flat_new_state_shapes = [] - - @function.Defun(*(nest.flatten( - sparse.as_dense_types( - self._state_types, self._state_classes)) + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) - def tf_reduce_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes)) - + nest.flatten( - sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes))): - arg.set_shape(shape) - - pivot = len(nest.flatten(self._state_shapes)) - nested_state_args = nest.pack_sequence_as(self._state_types, - args[:pivot]) - nested_state_args = sparse.deserialize_sparse_tensors( - nested_state_args, self._state_types, self._state_shapes, - self._state_classes) - nested_input_args = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) - nested_input_args = sparse.deserialize_sparse_tensors( - nested_input_args, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - - ret = reduce_func(nested_state_args, nested_input_args) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - # Extract shape information from the returned values. - flat_new_state = nest.flatten(ret) - flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state]) - - # Extract and validate type information from the returned values. - for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)): - if t.dtype != dtype: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_types, - nest.pack_sequence_as(self._state_types, - [t.dtype for t in flat_new_state]))) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, - [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - # Use the private method that will execute `tf_reduce_func` but delay - # adding it to the graph in case we need to rerun the function. - tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access - + wrapped_func = dataset_ops.StructuredFunctionWrapper( + reduce_func, "tf.contrib.data.group_by_reducer()", + input_classes=(self._state_classes, input_dataset.output_classes), + input_shapes=(self._state_shapes, input_dataset.output_shapes), + input_types=(self._state_types, input_dataset.output_types), + add_to_graph=False) + + # Extract and validate class information from the returned values. + for new_state_class, state_class in zip( + nest.flatten(wrapped_func.output_classes), + nest.flatten(self._state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, wrapped_func.output_classes)) + + # Extract and validate type information from the returned values. + for new_state_type, state_type in zip( + nest.flatten(wrapped_func.output_types), + nest.flatten(self._state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, wrapped_func.output_types)) + + # Extract shape information from the returned values. flat_state_shapes = nest.flatten(self._state_shapes) + flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes) weakened_state_shapes = [ - old.most_specific_compatible_shape(new) - for old, new in zip(flat_state_shapes, flat_new_state_shapes) + original.most_specific_compatible_shape(new) + for original, new in zip(flat_state_shapes, flat_new_state_shapes) ] need_to_rerun = False - for old_shape, weakened_shape in zip(flat_state_shapes, - weakened_state_shapes): - if old_shape.ndims is not None and ( + for original_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if original_shape.ndims is not None and ( weakened_shape.ndims is None or - old_shape.as_list() != weakened_shape.as_list()): + original_shape.as_list() != weakened_shape.as_list()): need_to_rerun = True break @@ -427,50 +348,19 @@ class GroupByReducerDataset(dataset_ops.Dataset): self._state_shapes = nest.pack_sequence_as(self._state_shapes, weakened_state_shapes) - self._reduce_func = tf_reduce_func + self._reduce_func = wrapped_func.function self._reduce_func.add_to_graph(ops.get_default_graph()) def _make_finalize_func(self, finalize_func): """Make wrapping Defun for finalize_func.""" - - @function.Defun(*(nest.flatten( - sparse.as_dense_types(self._state_types, self._state_classes)))) - def tf_finalize_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes))): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(self._state_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, self._state_types, self._state_shapes, - self._state_classes) - - ret = finalize_func(nested_args) - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - - self._output_classes = sparse.get_classes(ret) - self._output_shapes = nest.pack_sequence_as( - ret, [t.get_shape() for t in nest.flatten(ret)]) - self._output_types = nest.pack_sequence_as( - ret, [t.dtype for t in nest.flatten(ret)]) - - # Serialize any sparse tensors. - ret = nest.pack_sequence_as( - ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) - return nest.flatten(ret) - - self._finalize_func = tf_finalize_func - self._finalize_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + finalize_func, "tf.contrib.data.group_by_reducer()", + input_classes=self._state_classes, input_shapes=self._state_shapes, + input_types=self._state_types) + self._finalize_func = wrapped_func.function + self._output_classes = wrapped_func.output_classes + self._output_shapes = wrapped_func.output_shapes + self._output_types = wrapped_func.output_types @property def output_classes(self): @@ -495,18 +385,15 @@ class GroupByReducerDataset(dataset_ops.Dataset): init_func=self._init_func, reduce_func=self._reduce_func, finalize_func=self._finalize_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) -class GroupByWindowDataset(dataset_ops.Dataset): +class _GroupByWindowDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a windowed reduction.""" def __init__(self, input_dataset, key_func, reduce_func, window_size_func): """See `group_by_window()` for details.""" - super(GroupByWindowDataset, self).__init__() + super(_GroupByWindowDataset, self).__init__() self._input_dataset = input_dataset @@ -516,64 +403,39 @@ class GroupByWindowDataset(dataset_ops.Dataset): def _make_window_size_func(self, window_size_func): """Make wrapping Defun for window_size_func.""" - - @function.Defun(dtypes.int64) - def tf_window_size_func(key): - key.set_shape([]) - window_size = ops.convert_to_tensor( - window_size_func(key), dtype=dtypes.int64) - if window_size.dtype != dtypes.int64: - raise ValueError( - "`window_size_func` must return a single tf.int64 tensor.") - return window_size - - self._window_size_func = tf_window_size_func - self._window_size_func.add_to_graph(ops.get_default_graph()) + def window_size_func_wrapper(key): + return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + window_size_func_wrapper, "tf.contrib.data.group_by_window()", + input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(), + input_types=dtypes.int64) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`window_size_func` must return a single tf.int64 scalar tensor.") + self._window_size_func = wrapped_func.function def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - - @function.Defun(*nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes))) - def tf_key_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, nest.flatten(dense_shapes)): - arg.set_shape(shape) - - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - nested_args = sparse.deserialize_sparse_tensors( - nested_args, input_dataset.output_types, input_dataset.output_shapes, - input_dataset.output_classes) - # pylint: disable=protected-access - if dataset_ops._should_unpack_args(nested_args): - ret = key_func(*nested_args) - # pylint: enable=protected-access - else: - ret = key_func(nested_args) - ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) - if ret.dtype != dtypes.int64: - raise ValueError("`key_func` must return a single tf.int64 tensor.") - return ret - - self._key_func = tf_key_func - self._key_func.add_to_graph(ops.get_default_graph()) + def key_func_wrapper(*args): + return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset) + if not ( + wrapped_func.output_types == dtypes.int64 and + wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())): + raise ValueError( + "`key_func` must return a single tf.int64 scalar tensor.") + self._key_func = wrapped_func.function def _make_reduce_func(self, reduce_func, input_dataset): """Make wrapping Defun for reduce_func.""" - - @function.Defun(dtypes.int64, dtypes.variant) - def tf_reduce_func(key, window_dataset_variant): - """A wrapper for Defun that facilitates shape inference.""" - key.set_shape([]) + def reduce_func_wrapper(key, window_dataset_variant): + """Wrapper that converts between tf.variant and Dataset objects.""" window_dataset = _VariantDataset( window_dataset_variant, input_dataset.output_types, input_dataset.output_shapes, input_dataset.output_classes) - if not isinstance(window_dataset, dataset_ops.Dataset): - raise TypeError("`window_dataset` must return a `Dataset` object.") output_dataset = reduce_func(key, window_dataset) if not isinstance(output_dataset, dataset_ops.Dataset): raise TypeError("`reduce_func` must return a `Dataset` object.") @@ -582,8 +444,12 @@ class GroupByWindowDataset(dataset_ops.Dataset): self._output_shapes = output_dataset.output_shapes return output_dataset._as_variant_tensor() # pylint: disable=protected-access - self._reduce_func = tf_reduce_func - self._reduce_func.add_to_graph(ops.get_default_graph()) + wrapped_func = dataset_ops.StructuredFunctionWrapper( + reduce_func_wrapper, "tf.contrib.data.reduce_by_window()", + input_classes=(ops.Tensor, ops.Tensor), + input_shapes=(tensor_shape.scalar(), tensor_shape.scalar()), + input_types=(dtypes.int64, dtypes.variant)) + self._reduce_func = wrapped_func.function @property def output_classes(self): @@ -606,10 +472,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): key_func=self._key_func, reduce_func=self._reduce_func, window_size_func=self._window_size_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) class Reducer(object): diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index be66fbac50753c8f54b62dd615ee60804f4cf20d..bcc959594a6b311a3c60bb4696ac97be5c448756 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -24,7 +24,6 @@ from tensorflow.contrib.data.python.ops import random_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -154,7 +153,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): return _apply_fn -class DirectedInterleaveDataset(dataset_ops.Dataset): +class _DirectedInterleaveDataset(dataset_ops.Dataset): """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" def __init__(self, selector_input, data_inputs): @@ -171,10 +170,7 @@ class DirectedInterleaveDataset(dataset_ops.Dataset): return gen_dataset_ops.directed_interleave_dataset( self._selector_input._as_variant_tensor(), [data_input._as_variant_tensor() for data_input in self._data_inputs], - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property @@ -240,7 +236,7 @@ def sample_from_datasets(datasets, weights=None, seed=None): selector_input = dataset_ops.Dataset.zip( (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) - return DirectedInterleaveDataset(selector_input, datasets) + return _DirectedInterleaveDataset(selector_input, datasets) def choose_from_datasets(datasets, choice_dataset): @@ -284,4 +280,4 @@ def choose_from_datasets(datasets, choice_dataset): and choice_dataset.output_classes == ops.Tensor): raise TypeError("`choice_dataset` must be a dataset of scalar " "`tf.int64` tensors.") - return DirectedInterleaveDataset(choice_dataset, datasets) + return _DirectedInterleaveDataset(choice_dataset, datasets) diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index cad41bce2961f29a7591fe3d382d1ab35a6b38b4..cf896572262929add5ac34d4fc8e4192c1049da3 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -19,8 +19,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops @@ -41,17 +39,17 @@ def optimize(optimizations=None): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - return OptimizeDataset(dataset, optimizations) + return _OptimizeDataset(dataset, optimizations) return _apply_fn -class OptimizeDataset(dataset_ops.Dataset): +class _OptimizeDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and applies optimizations.""" def __init__(self, input_dataset, optimizations): """See `optimize()` for details.""" - super(OptimizeDataset, self).__init__() + super(_OptimizeDataset, self).__init__() self._input_dataset = input_dataset if optimizations is None: optimizations = [] @@ -62,10 +60,7 @@ class OptimizeDataset(dataset_ops.Dataset): return gen_dataset_ops.optimize_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._optimizations, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py index 28ef5e50f39dd7d1b6f124e58e068fc968ddd6dc..e670c4c8354f4067eb21c9b1fce708147c162967 100644 --- a/tensorflow/contrib/data/python/ops/random_ops.py +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -18,9 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -39,10 +37,7 @@ class RandomDataset(dataset_ops.Dataset): return gen_dataset_ops.random_dataset( seed=self._seed, seed2=self._seed2, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index f938153f5f8c8becc5877a667117fd6facd3e428..83095c7ba1c6465d18490e5197f71bf7f1fe2497 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -26,6 +26,7 @@ from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import convert @@ -754,6 +755,8 @@ def make_batched_features_dataset(file_pattern, dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + dataset = dataset.apply(stats_ops.feature_stats("record_stats")) + if drop_final_batch: dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) else: diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index bad6edd5147d832228c412919f1e6e782aafc40f..182a5c6ff36fcda8c9e2c522cce07bed0c2daec9 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -291,4 +291,4 @@ def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): # TODO(joelshor): Simplify fraction, if possible. a_i = (ratio_l - m) / (max_ratio - m) - return a_i, m \ No newline at end of file + return a_i, m diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index e911ad0fa0541f2d8b991d66182dd002c2ecaab0..ea9dcfe68fa2630d915323fa295031af7d48cdfb 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -22,7 +22,6 @@ import collections from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse -from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import gen_dataset_ops @@ -67,102 +66,45 @@ class _ScanDataset(dataset_ops.Dataset): need_to_rerun = True while need_to_rerun: - # Create a list in which `tf_scan_func` will store the new shapes. - flat_new_state_shapes = [] - - @function.Defun(*(nest.flatten( - sparse.as_dense_types( - self._state_types, self._state_classes)) + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) - def tf_scan_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the state and input_dataset. - for arg, shape in zip( - args, - nest.flatten( - sparse.as_dense_shapes(self._state_shapes, self._state_classes)) - + nest.flatten( - sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes))): - arg.set_shape(shape) - - pivot = len(nest.flatten(self._state_shapes)) - print(self._state_classes) - nested_state_args = nest.pack_sequence_as(self._state_types, - args[:pivot]) - nested_state_args = sparse.deserialize_sparse_tensors( - nested_state_args, self._state_types, self._state_shapes, - self._state_classes) - print(input_dataset.output_classes) - nested_input_args = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) - nested_input_args = sparse.deserialize_sparse_tensors( - nested_input_args, input_dataset.output_types, - input_dataset.output_shapes, input_dataset.output_classes) - - ret = scan_func(nested_state_args, nested_input_args) - if not isinstance(ret, collections.Sequence) or len(ret) != 2: - raise TypeError("The scan function must return a pair comprising the " - "new state and the output value.") - - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - ret = nest.pack_sequence_as(ret, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) - for t in nest.flatten(ret) - ]) - new_state, output_value = ret - - # Extract and validate class information from the returned values. - for t, clazz in zip( - nest.flatten(new_state), nest.flatten(self._state_classes)): - if not isinstance(t, clazz): - raise TypeError( - "The element classes for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_classes, - nest.pack_sequence_as( - self._state_types, - [type(t) for t in nest.flatten(new_state)]))) - self._output_classes = sparse.get_classes(output_value) - - # Extract shape information from the returned values. - flat_new_state_shapes.extend( - [t.get_shape() for t in nest.flatten(new_state)]) - self._output_shapes = nest.pack_sequence_as( - output_value, [t.get_shape() for t in nest.flatten(output_value)]) - - # Extract and validate type information from the returned values. - for t, dtype in zip( - nest.flatten(new_state), nest.flatten(self._state_types)): - if t.dtype != dtype: - raise TypeError( - "The element types for the new state must match the initial " - "state. Expected %s; got %s." % - (self._state_types, - nest.pack_sequence_as( - self._state_types, - [t.dtype for t in nest.flatten(new_state)]))) - self._output_types = nest.pack_sequence_as( - output_value, [t.dtype for t in nest.flatten(output_value)]) - - # Serialize any sparse tensors. - new_state = nest.pack_sequence_as(new_state, [ - t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state)) - ]) - output_value = nest.pack_sequence_as(output_value, [ - t for t in nest.flatten( - sparse.serialize_sparse_tensors(output_value)) - ]) - return nest.flatten(new_state) + nest.flatten(output_value) - - # Use the private method that will execute `tf_scan_func` but delay - # adding it to the graph in case we need to rerun the function. - tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access + wrapped_func = dataset_ops.StructuredFunctionWrapper( + scan_func, "tf.contrib.data.scan()", + input_classes=(self._state_classes, input_dataset.output_classes), + input_shapes=(self._state_shapes, input_dataset.output_shapes), + input_types=(self._state_types, input_dataset.output_types), + add_to_graph=False) + if not ( + isinstance(wrapped_func.output_types, collections.Sequence) and + len(wrapped_func.output_types) == 2): + raise TypeError("The scan function must return a pair comprising the " + "new state and the output value.") + + new_state_classes, self._output_classes = wrapped_func.output_classes + + # Extract and validate class information from the returned values. + for new_state_class, state_class in zip( + nest.flatten(new_state_classes), + nest.flatten(self._state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, new_state_classes)) + + # Extract and validate type information from the returned values. + new_state_types, self._output_types = wrapped_func.output_types + for new_state_type, state_type in zip( + nest.flatten(new_state_types), nest.flatten(self._state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, new_state_types)) + + # Extract shape information from the returned values. + new_state_shapes, self._output_shapes = wrapped_func.output_shapes flat_state_shapes = nest.flatten(self._state_shapes) + flat_new_state_shapes = nest.flatten(new_state_shapes) weakened_state_shapes = [ original.most_specific_compatible_shape(new) for original, new in zip(flat_state_shapes, flat_new_state_shapes) @@ -178,12 +120,10 @@ class _ScanDataset(dataset_ops.Dataset): break if need_to_rerun: - # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun - # `tf_scan_func`. self._state_shapes = nest.pack_sequence_as(self._state_shapes, weakened_state_shapes) - self._scan_func = tf_scan_func + self._scan_func = wrapped_func.function self._scan_func.add_to_graph(ops.get_default_graph()) def _as_variant_tensor(self): @@ -193,10 +133,7 @@ class _ScanDataset(dataset_ops.Dataset): nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)), self._scan_func.captured_inputs, f=self._scan_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index f35795abd38000b13cec0f08596e2ff66e86286c..d7f8a73fe3d67bb83e44e962832ce34c116aef66 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -18,9 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed -from tensorflow.python.data.util import sparse from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -56,10 +54,7 @@ class _ShuffleAndRepeatDataset(dataset_ops.Dataset): count=self._count, seed=self._seed, seed2=self._seed2, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 19cc3cb89fc5c494f79ce1d25ed57c92099c8bd2..f935beb1a9e85d4901857e7781a5ed8473838fa5 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -19,7 +19,6 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -43,10 +42,7 @@ class _SlideDataset(dataset_ops.Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access window_size=self._window_size, stride=self._stride, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 3cbaab5affd7397213b0fbb6b0682db92b99d591..97931f75bd37d9e45864fe477c6e1620b5e4f193 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -18,13 +18,13 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. class StatsAggregator(object): """A stateful resource that aggregates statistics from one or more iterators. @@ -97,10 +97,7 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset): return gen_dataset_ops.set_stats_aggregator_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._stats_aggregator._resource, # pylint: disable=protected-access - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): @@ -115,7 +112,8 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset): return self._input_dataset.output_classes -# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`. +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def set_stats_aggregator(stats_aggregator): """Set the given stats_aggregator for aggregating the input dataset stats. @@ -133,6 +131,8 @@ def set_stats_aggregator(stats_aggregator): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def bytes_produced_stats(tag): """Records the number of bytes produced by each element of the input dataset. @@ -155,6 +155,8 @@ def bytes_produced_stats(tag): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def latency_stats(tag): """Records the latency of producing each element of the input dataset. @@ -176,6 +178,29 @@ def latency_stats(tag): return _apply_fn +# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. +def feature_stats(tag): + """Records the features stats from `Example` records of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with the output + dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will be + associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag) + + return _apply_fn + + class _StatsDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and also records statistics.""" @@ -189,10 +214,7 @@ class _StatsDataset(dataset_ops.Dataset): return self._op_function( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._tag, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index 56f67e1766bbaff680bdff6b939df0c3ba68c679..9af1e784ffb4f6d71da25f09d60343b649c5079b 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -22,8 +22,6 @@ import threading from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.ops import resource_variable_ops @@ -39,22 +37,28 @@ def _generate_shared_name(prefix): return "{}{}".format(prefix, uid) +# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. class PrivateThreadPool(object): """A stateful resource that represents a private thread pool.""" - def __init__(self, num_threads, display_name=None): + def __init__(self, num_threads, display_name=None, + max_intra_op_parallelism=1): """Creates a `PrivateThreadPool` with the given number of threads.""" if context.executing_eagerly(): shared_name = _generate_shared_name("privatethreadpool") self._resource = gen_dataset_ops.thread_pool_handle( num_threads=num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, display_name=display_name, shared_name=shared_name) self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device=context.context().device_name) else: self._resource = gen_dataset_ops.thread_pool_handle( - num_threads=num_threads, display_name=display_name) + num_threads=num_threads, + max_intra_op_parallelism=max_intra_op_parallelism, + display_name=display_name) class _ThreadPoolDataset(dataset_ops.Dataset): @@ -69,10 +73,7 @@ class _ThreadPoolDataset(dataset_ops.Dataset): return gen_dataset_ops.thread_pool_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._thread_pool._resource, # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): @@ -87,6 +88,8 @@ class _ThreadPoolDataset(dataset_ops.Dataset): return self._input_dataset.output_classes +# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable +# or make private / remove. def override_threadpool(dataset, thread_pool): """Returns a new dataset that uses the given thread pool for its operations. diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index 765ef3f9b6d42c9d7af3ce4916731d37d65c9260..e0ce0a4ef15f6b9181bce92fb4d73bf1fab2e66c 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -44,17 +42,17 @@ def unique(): """ def _apply_fn(dataset): - return UniqueDataset(dataset) + return _UniqueDataset(dataset) return _apply_fn -class UniqueDataset(dataset_ops.Dataset): +class _UniqueDataset(dataset_ops.Dataset): """A `Dataset` contains the unique elements from its input.""" def __init__(self, input_dataset): """See `unique()` for details.""" - super(UniqueDataset, self).__init__() + super(_UniqueDataset, self).__init__() self._input_dataset = input_dataset if input_dataset.output_types not in (dtypes.int32, dtypes.int64, dtypes.string): @@ -65,10 +63,7 @@ class UniqueDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.unique_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index a91c54153f465d638d3df1b24dab38987920d825..eba0dd0ea330e29db0ea8e68ee14767fcb8ddad0 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -77,6 +77,7 @@ py_library( "//tensorflow/python:device_util", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -148,6 +149,7 @@ py_library( ], deps = [ ":mirrored_strategy", + ":multi_worker_strategy", ":one_device_strategy", ":tpu_strategy", "//tensorflow/contrib/optimizer_v2:training", @@ -311,7 +313,6 @@ cuda_py_test( tags = [ "multi_and_single_gpu", "no_pip", - "noguitar", # TODO(b/109653107): test is flaky. ], ) @@ -447,8 +448,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":values", + "//tensorflow/contrib/all_reduce:all_reduce_py", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", ], @@ -496,6 +499,7 @@ cuda_py_test( additional_deps = [ ":combinations", ":cross_tower_ops", + ":multi_worker_test_base", ":values", "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", @@ -505,6 +509,7 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], + shard_count = 15, tags = [ "multi_and_single_gpu", "no_pip", @@ -585,3 +590,22 @@ cuda_py_test( "notsan", ], ) + +cuda_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 98e7228f24d8ca8cec594e40f1937fd2415a1c38..9a8ea4aa48b8cf4c5906f18d8bddacc224e0b644 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -47,6 +47,7 @@ from absl.testing import parameterized import six from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib +from tensorflow.contrib.distribute.python import multi_worker_strategy from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adam as adam_v2 @@ -320,10 +321,6 @@ default_strategy = NamedDistribution( one_device_strategy = NamedDistribution( "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) -tpu_strategy_single_iteration = NamedDistribution( - "TPUSingleIteration", - lambda: tpu_lib.TPUStrategy(iterations_per_step=1), - required_tpu=True) tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. @@ -338,6 +335,34 @@ mirrored_strategy_with_two_gpus = NamedDistribution( ["/gpu:0", "/gpu:1"], prefetch_on_device=False), required_gpus=2) +multi_worker_strategy_with_cpu = NamedDistribution( + "MultiWorkerCPU", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=0), 0) +multi_worker_strategy_with_one_gpu = NamedDistribution( + "MultiWorker1GPU", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=1), 1) +multi_worker_strategy_with_two_gpus = NamedDistribution( + "MultiWorker2GPUs", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=2), 2) + adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) gradient_descent_optimizer_v1_fn = NamedObject( diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index a411b880e80291e50516c180fa618056cbee78d3..1009c3c0124c254ee2b69ccc161c9a108bfb855c 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import six from tensorflow.contrib.distribute.python import cross_tower_utils @@ -234,7 +235,13 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): def _group_value_by_device(per_device_values): """Group values into sublists by their devices. - This grouping is needed to call the all-reduce library. + This grouping is needed to call the all-reduce library because it expects a + list of the following form: + [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ... + (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ... + (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ... + ... + ] Args: per_device_values: a list of PerDevice obejcts. @@ -322,7 +329,17 @@ class ConcatAndSplitPacker(object): # TODO(zhengxq): it is also possible to optimize away all the concat # as well. num_splits = self.num_packs - total_grad_size = array_ops.size(concat_grads) + + # The array_ops.size function will sometimes remove static shapes. So if + # all gradient shapes are defined, we use another method to get the + # total size. + # TODO(yuefengz): move this logic to array_ops.size. + if all([g.shape.is_fully_defined() for g, _ in tower_grads_and_vars]): + total_grad_size = sum( + [g.shape.num_elements() for g, _ in tower_grads_and_vars]) + else: + total_grad_size = array_ops.size(concat_grads) + split_size = total_grad_size // num_splits split_size_last = total_grad_size - split_size * (num_splits - 1) split_sizes = [split_size] * (num_splits - 1) + [split_size_last] @@ -412,6 +429,31 @@ class AggregateSmallTensorPacker(object): self.packing) +def _pack_tensors(device_grads, + num_packs=0, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=0): + """Pack tensors if specified.""" + if num_packs > 0: + tensor_packer = ConcatAndSplitPacker(num_packs) + device_grad_packs = tensor_packer.pack(device_grads) + elif agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0: + tensor_packer = AggregateSmallTensorPacker(agg_small_grads_max_bytes, + agg_small_grads_max_group) + device_grad_packs = tensor_packer.pack(device_grads) + else: + tensor_packer = None + device_grad_packs = device_grads + return device_grad_packs, tensor_packer + + +def _unpack_tensors(reduced, tensor_packer=None): + """Unpack tensors if they are packed before all-reduce.""" + if tensor_packer: + return tensor_packer.unpack(reduced) + return reduced + + class AllReduceCrossTowerOps(CrossTowerOps): """Reduction using all reduce.""" @@ -440,10 +482,10 @@ class AllReduceCrossTowerOps(CrossTowerOps): agg_small_grads_max_group: see above. tensors. """ - self.all_reduce_alg = all_reduce_alg - self.num_packs = num_packs - self.agg_small_grads_max_bytes = agg_small_grads_max_bytes - self.agg_small_grads_max_group = agg_small_grads_max_group + self._all_reduce_alg = all_reduce_alg + self._num_packs = num_packs + self._agg_small_grads_max_bytes = agg_small_grads_max_bytes + self._agg_small_grads_max_group = agg_small_grads_max_group super(AllReduceCrossTowerOps, self).__init__() def _reduce(self, method_string, per_device_value, destinations): @@ -485,37 +527,24 @@ class AllReduceCrossTowerOps(CrossTowerOps): def _batch_all_reduce(self, method_string, per_device_values): """All reduce algorithm in a batch.""" + logging.info( + "batch_all_reduce invoked for batches size = %d with " + "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " + "agg_small_grads_max_group = %d", len(per_device_values), + self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) destinations = per_device_values[0].devices grouped = _group_value_by_device(per_device_values) - if self.num_packs > 0: - logging.info( - "batch_all_reduce invoked for batches size = %d with " - "algorithm = %s and num_packs = %d", len(per_device_values), - self.all_reduce_alg, self.num_packs) - tensor_packer = ConcatAndSplitPacker(self.num_packs) - device_grad_packs = tensor_packer.pack(grouped) - elif (self.agg_small_grads_max_bytes > 0 and - self.agg_small_grads_max_group > 0): - logging.info( - "batch_all_reduce invoked for batches size = %d with " - "algorithm = %s, agg_small_grads_max_bytes = %d and " - "agg_small_grads_max_group = %d", len(per_device_values), - self.all_reduce_alg, self.agg_small_grads_max_bytes, - self.agg_small_grads_max_group) - tensor_packer = AggregateSmallTensorPacker( - self.agg_small_grads_max_bytes, self.agg_small_grads_max_group) - device_grad_packs = tensor_packer.pack(grouped) - else: - logging.info( - "batch_all_reduce invoked for batches size = %d with algorithm = %s", - len(per_device_values), self.all_reduce_alg) - tensor_packer = None - device_grad_packs = grouped + + device_grad_packs, tensor_packer = _pack_tensors( + grouped, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) # The actual aggregation of the repacked gradients. Note that they are # sharded among different aggregation trees. So it is important to strike # the balance on num_splits. - if self.all_reduce_alg == "nccl": + if self._all_reduce_alg == "nccl": + # TODO(yuefengz): merge this into the all-reduce library. reduced = cross_tower_utils.aggregate_gradients_using_nccl( device_grad_packs) else: @@ -525,13 +554,137 @@ class AllReduceCrossTowerOps(CrossTowerOps): cross_tower_utils.aggregate_gradients_using_hierarchical_copy( destinations, device_grad_packs)) - if tensor_packer: - reduced = tensor_packer.unpack(reduced) - + reduced = _unpack_tensors(reduced, tensor_packer) return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, method_string) +AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", + "alg shards limit") + + +class MultiWorkerAllReduce(AllReduceCrossTowerOps): + """All-reduce algorithms for distributed TensorFlow.""" + + def __init__(self, + worker_devices, + num_gpus_per_worker, + all_reduce_spec=("pscpu/pscpu", 2, -1), + num_packs=0, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=10): + """Initialize the all-reduce algorithm. + + Args: + worker_devices: a list of device strings for workers participating in + all-reduce. + num_gpus_per_worker: number of GPU devices per worker. + all_reduce_spec: a tuple or a named tuple or a list of tuples specifying + the all-reduce algorithm. + 1. The first element of a tuple is the name of the all-reduce algorithm. + Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd", + "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with + a "/" are hierarchical, so two all-reduces are executed, the first one + aggregates tensors within a worker and the second aggregates across + workers. + 2. The second element of a tuple is the number of shards when doing + all-reduce. Let's say its values is M, each tensor after packing will be + split into M shards and then M parallel all-reduces would be performed + before finally they are concatenated backed into a complete tensor. + 3. The third element is the maximum size of tensors that will be + applicable for the algorithm specified by the first element. For + example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)], + tensors with size not larger than 1024 bytes will be applied a 2-shard + "nccl" all-reduce and other tensors will be applied a 2-shard + "pscpu/pscpu" algorithm. The third elements should be in increasing + order across tuples and end with -1 which indicates infinity. + num_packs: see AllReduceCrossTowerOps. + agg_small_grads_max_bytes: see AllReduceCrossTowerOps. + agg_small_grads_max_group: see AllReduceCrossTowerOps. + """ + self._worker_devices = worker_devices + self._num_gpus_per_worker = num_gpus_per_worker + super(MultiWorkerAllReduce, self).__init__( + num_packs=num_packs, + agg_small_grads_max_bytes=agg_small_grads_max_bytes, + agg_small_grads_max_group=agg_small_grads_max_group) + + def validate_and_complete_spec(spec): + """Validate and complete the all-reduce spec.""" + # TODO(yuefengz): support namedtuple. + if not isinstance(spec, tuple): + raise ValueError( + "A tuple is expected for all-reduce spec: %r" % all_reduce_spec) + if not spec or len(spec) > 3: + raise ValueError( + "Too many elements in the all-reduce spec tuple: %r" % spec) + if len(spec) == 1: + return AllReduceSpecTuple(spec[0], 1, -1) + elif len(spec) == 2: + return AllReduceSpecTuple(spec[0], spec[1], -1) + else: + return AllReduceSpecTuple(*spec) + + self._all_reduce_spec = [] + if isinstance(all_reduce_spec, six.string_types): + self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1)) + elif isinstance(all_reduce_spec, tuple): + self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec)) + elif isinstance(all_reduce_spec, list): + self._all_reduce_spec = [ + validate_and_complete_spec(spec) for spec in all_reduce_spec + ] + + def _batch_all_reduce(self, method_string, per_device_values): + """All reduce algorithm in a batch.""" + logging.info( + "distributed batch_all_reduce invoked for batches size = %d with " + "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " + "and agg_small_grads_max_group = %d", len(per_device_values), + self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + + destinations = sorted(per_device_values[0].devices) + device_grads = _group_value_by_device(per_device_values) + + # The all reduce library requires fully defined shapes. + # TODO(yuefengz): when tensor sharding is not needed, static shapes are not + # required as well. + for device_grad in device_grads: + for grad, _ in device_grad: + if not grad.shape.is_fully_defined(): + raise ValueError("Shape is unknown for node %r" % grad) + + remaining_grads = device_grads + aggregated_grads = [] + for spec_tuple in self._all_reduce_spec: + if spec_tuple.limit < 0: + this_grads = remaining_grads + remaining_grads = [] + else: + (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size( + spec_tuple.limit, remaining_grads) + if this_grads: + device_grad_packs, tensor_packer = _pack_tensors( + this_grads, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + range_agg_grads = cross_tower_utils.sum_gradients_all_reduce( + self._worker_devices, device_grad_packs, len(self._worker_devices), + spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker)) + range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer) + + if not aggregated_grads: + aggregated_grads = range_agg_grads + else: + assert len(aggregated_grads) == len(range_agg_grads) + for i in range(len(aggregated_grads)): + aggregated_grads[i] += range_agg_grads[i] + assert not remaining_grads + + return _ungroup_and_make_mirrored(aggregated_grads, destinations, + method_string) + + _dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 2a266326088def94a5c1bee11ab6ec1a0ccf0f49..fed5505d92ef2544215069736c166a67d6141708 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -75,7 +76,7 @@ def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): _cpu_device = "/device:CPU:0" -class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): +class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase): def _assert_indexed_slices_equal(self, left, right): self.assertIsInstance(left, ops.IndexedSlices) @@ -94,7 +95,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): self.assertEqual(type(left), type(right)) self.assertEqual(left.devices, right.devices) if isinstance(list(left._index.values())[0], ops.IndexedSlices): - for (d, v) in left._index.iteritems(): + for (d, v) in left._index.items(): self._assert_indexed_slices_equal(v, right._index[d]) elif context.executing_eagerly(): self.assertEqual([v.numpy() for v in left._index.values()], @@ -104,51 +105,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): self.assertEqual( sess.run(list(left._index.values())), list(right._index.values())) - # TODO(yuefengz): decouple the num_gpus check from distribution in - # combinations module so that we can pass in devices instead of a distribution - # strategy. - reduction_to_one_combinations = combinations.combine( - cross_tower_ops=[ - combinations.NamedObject( - "DefaultReductionToOneDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), - combinations.NamedObject( - "ReductionToCPUDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( - reduce_to_device=_cpu_device)), - combinations.NamedObject( - "AccumulateNCrossTowerOp", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( - accumulation_fn=math_ops.accumulate_n)), - ], - distribution=[ - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus - ], - mode=["graph", "eager"]) - allreduce_combinations = combinations.combine( - cross_tower_ops=[ - combinations.NamedObject( - "AllReduce", - cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)), - combinations.NamedObject( - "HierarchicalCopy", - cross_tower_ops_lib.AllReduceCrossTowerOps( - "hierarchical_copy", 8, 0, 0)), - combinations.NamedObject( - "AllReduceNoGradientRepacking", - cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)), - combinations.NamedObject( - "HierarchicalCopyAggregateSmallTensors", - cross_tower_ops_lib.AllReduceCrossTowerOps( - "hierarchical_copy", 0, 100, 10)) - ], - distribution=[combinations.mirrored_strategy_with_two_gpus], - mode=["graph", "eager"]) - - @combinations.generate(reduction_to_one_combinations + allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): + def _testReductionAndBroadcast(self, cross_tower_ops, distribution): devices = distribution.worker_devices values = [constant_op.constant(float(d)) for d in range(len(devices))] @@ -208,20 +165,70 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): cross_tower_ops.broadcast(constant_op.constant(1.), destinations), _fake_mirrored(1., destinations)) + +class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): + # TODO(yuefengz): decouple the num_gpus check from distribution in + # combinations module so that we can pass in devices instead of a distribution + # strategy. + reduction_to_one_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "DefaultReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "ReductionToCPUDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_cpu_device)), + combinations.NamedObject( + "AccumulateNCrossTowerOp", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + accumulation_fn=math_ops.accumulate_n)), + ], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + mode=["graph", "eager"]) + allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "AllReduce", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)), + combinations.NamedObject( + "HierarchicalCopy", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 8, 0, 0)), + combinations.NamedObject( + "AllReduceNoGradientRepacking", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)), + combinations.NamedObject( + "HierarchicalCopyAggregateSmallTensors", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 0, 100, 10)) + ], + distribution=[combinations.mirrored_strategy_with_two_gpus], + mode=["graph", "eager"]) + + @combinations.generate(reduction_to_one_combinations + allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + with distribution.scope(): + self._testReductionAndBroadcast(cross_tower_ops, distribution) + def testChooseAlgorithm(self): device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) - self.assertEqual(result.all_reduce_alg, "hierarchical_copy") - self.assertEqual(result.num_packs, 8) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) # if there are only 4 devices device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) - self.assertEqual(result.all_reduce_alg, "nccl") - self.assertEqual(result.num_packs, 1) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) # if devices links contain each device itself device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], @@ -229,16 +236,16 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) - self.assertEqual(result.all_reduce_alg, "hierarchical_copy") - self.assertEqual(result.num_packs, 8) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) # if not dgx1-like links device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) - self.assertEqual(result.all_reduce_alg, "nccl") - self.assertEqual(result.num_packs, 1) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) @combinations.generate(combinations.combine( mode=["graph", "eager"], @@ -316,5 +323,44 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): self._assert_values_equal(total_mirrored_without_dups, result) +class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, + CrossTowerOpsTestBase): + + worker_devices = [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + multi_worker_allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "MultiWorkerAllReduce", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReducePack", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReduceAggregation", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)), + combinations.NamedObject( + "MultiWorkerAllReduceMultipleSpecs", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, [("pscpu/pscpu", 2, 100), + ("xring", 2, -1)], 0, 0, 0)), + ], + distribution=[ + combinations.multi_worker_strategy_with_cpu, + combinations.multi_worker_strategy_with_one_gpu, + combinations.multi_worker_strategy_with_two_gpus + ], + mode=["graph"]) + + @combinations.generate(multi_worker_allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + with distribution.scope(): + self._testReductionAndBroadcast(cross_tower_ops, distribution) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index 137fabf4c739bb41104bceb9274df8284deef86d..2bb088e704c584598b863b1b836166af2a5bb12c 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections as pycoll from tensorflow.contrib import nccl +from tensorflow.contrib.all_reduce.python import all_reduce from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -158,6 +159,148 @@ def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, return (grad, v), None +def group_device_names(devices, group_size): + """Group device names into groups of group_size. + + Args: + devices: a list of canonical device strings. + group_size: integer which is equal to or greater than 1. + + Returns: + list of lists of devices, where each inner list is group_size long, + and each device appears at least once in an inner list. If + len(devices) % group_size == 0 then each device will appear exactly once. + + Raises: + ValueError: if group_size > len(devices) + """ + num_devices = len(devices) + if group_size > num_devices: + raise ValueError( + 'only %d devices, but group_size=%d' % (num_devices, group_size)) + num_groups = ( + num_devices // group_size + (1 if (num_devices % group_size != 0) else 0)) + groups = [[] for i in range(num_groups)] + for i in range(num_groups * group_size): + groups[i % num_groups].append(devices[i % num_devices]) + return groups + + +def split_grads_by_size(threshold_size, device_grads): + """Break gradients into two sets according to tensor size. + + Args: + threshold_size: int size cutoff for small vs large tensor. + device_grads: List of lists of (gradient, variable) tuples. The outer + list is over devices. The inner list is over individual gradients. + + Returns: + small_grads: Subset of device_grads where shape is <= threshold_size + elements. + large_grads: Subset of device_grads where shape is > threshold_size + elements. + """ + small_grads = [] + large_grads = [] + for dl in device_grads: + small_dl = [] + large_dl = [] + for (g, v) in dl: + tensor_size = g.get_shape().num_elements() + if tensor_size <= threshold_size: + small_dl.append([g, v]) + else: + large_dl.append([g, v]) + if small_dl: + small_grads.append(small_dl) + if large_dl: + large_grads.append(large_dl) + return small_grads, large_grads + + +def sum_grad_and_var_all_reduce(grad_and_vars, + num_workers, + alg, + gpu_indices, + aux_devices=None, + num_shards=1): + """Apply all-reduce algorithm over specified gradient tensors.""" + with ops.name_scope('allreduce'): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) + scaled_grads = [g for g, _ in grad_and_vars] + if alg == 'nccl': + summed_grads = nccl.all_sum(scaled_grads) + elif alg == 'xring': + summed_grads = all_reduce.build_ring_all_reduce( + scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add) + elif alg == 'nccl/xring': + summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards, + math_ops.add) + elif alg == 'nccl/rechd': + summed_grads = all_reduce.build_nccl_then_recursive_hd( + scaled_grads, math_ops.add) + elif alg == 'nccl/pscpu': + summed_grads = all_reduce.build_nccl_then_shuffle( + scaled_grads, aux_devices, math_ops.add, math_ops.add_n) + elif alg == 'pscpu/pscpu': + second_gather_devices = aux_devices[:num_shards] + summed_grads = all_reduce.build_shuffle_then_shuffle( + scaled_grads, aux_devices, second_gather_devices, math_ops.add_n) + elif alg in ['pscpu', 'psgpu']: + summed_grads = all_reduce.build_shuffle_all_reduce( + scaled_grads, aux_devices, math_ops.add_n) + else: + raise ValueError('unsupported all_reduce alg: ', alg) + + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result + + +def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, + num_shards, gpu_indices): + """Apply all-reduce algorithm over specified gradient tensors. + + Args: + dev_prefixes: list of prefix strings to use to generate PS device names. + tower_grads: the gradients to reduce. + num_workers: number of worker processes across entire job. + alg: the all-reduce algorithm to apply. + num_shards: alg-specific sharding factor. + gpu_indices: indices of local GPUs in order usable for ring-reduce. + + Returns: + list of reduced tensors + """ + alg_contains_shuffle = any([n in alg for n in ['pscpu', 'psgpu']]) + is_hierarchical = '/' in alg + if 'pscpu' in alg: + aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes] + elif 'psgpu' in alg: + aux_devices = [ + prefix + '/gpu:%d' % i + for i in range(len(gpu_indices)) + for prefix in dev_prefixes + ] + else: + aux_devices = ['/job:localhost/cpu:0'] + # Auxiliary devices for hierarchical all-reduces. + aux_device_groups = group_device_names( + aux_devices, num_shards if alg_contains_shuffle else 1) + group_index = 0 + reduced_gv_list = [] + for grad_and_vars in zip(*tower_grads): + reduced_gv_list.append( + sum_grad_and_var_all_reduce( + grad_and_vars, num_workers, alg, gpu_indices, aux_devices + if is_hierarchical else aux_device_groups[group_index], num_shards)) + group_index = (group_index + 1) % len(aux_device_groups) + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return new_tower_grads + + def extract_ranges(index_list, range_size_limit=32): """Extract consecutive ranges and singles from index_list. @@ -330,7 +473,7 @@ def unpack_small_tensors(tower_grads, packing): for dev_idx, gv_list in enumerate(tower_grads): gv_list = list(gv_list) new_gv_list = gv_list[num_packed:] - for i in xrange(0, num_packed): + for i in range(num_packed): k = '%d:%d' % (dev_idx, i) gpt = packing[k] gv = unpack_grad_tuple(gv_list[i], gpt) diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6bf143098c1bba64d47efce1bfface7682683d --- /dev/null +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -0,0 +1,438 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for V1 metrics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import variables + + +def _labeled_dataset_fn(): + # First four batches of x: labels, predictions -> (labels == predictions) + # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False + # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False + # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False + # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True + return dataset_ops.Dataset.range(1000).map( + lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4) + + +def _boolean_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # T, T -> TP; F, T -> FP; T, F -> FN + # F, F -> TN; T, T -> TP; F, T -> FP + # T, F -> FN; F, F -> TN; T, T -> TP + # F, T -> FP; T, F -> FN; F, F -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [True, True, False, False]}).repeat().batch(3) + + +def _threshold_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN + # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP + # True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP + # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3) + + +def _regression_dataset_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [1., .5, 1., 0.], + "predictions": [1., .75, .25, 0.]}).repeat() + + +def all_combinations(): + return combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + mode=["graph"]) + + +# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k, +# metrics.precision_at_k +class MetricsV1Test(test.TestCase, parameterized.TestCase): + + def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): + with ops.Graph().as_default(), distribution.scope(): + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + value, update = distribution.call_for_each_tower( + metric_fn, iterator.get_next()) + update = distribution.group(update) + self.evaluate(variables.local_variables_initializer()) + # TODO(josh11b): Once we switch to using a global batch size for input, + # replace "distribution.num_towers" with "1". + batches_per_update = distribution.num_towers + + # Update variables using the first `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value), + 0.001, msg="After first update") + + # Update variables using the second `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(2 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After second update") + + if batches_per_update == 1: # Consume 4 input batches + self.evaluate(update) + self.assertAllClose(expected_fn(3 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After third update") + self.evaluate(update) + self.assertAllClose(expected_fn(4 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After fourth update") + + @combinations.generate(all_combinations()) + def testMean(self, distribution): + def _dataset_fn(): + return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4) + + def _expected_fn(num_batches): + # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc. + return num_batches * 2 - 0.5 + + self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn) + + @combinations.generate(all_combinations()) + def testAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.accuracy(labels, predictions) + + def _expected_fn(num_batches): + return [3./4, 3./8, 3./12, 4./16][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanPerClassAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_per_class_accuracy( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1., 1., 1., 0., 0.]), + mean([0.5, 0.5, 0.5, 0., 0.]), + mean([1./3, 1./3, 0.5, 0., 0.]), + mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanIOU(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_iou( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch + mean([1./4, 1./4, 1./3, 0., 0.]), + mean([1./6, 1./6, 1./5, 0., 0.]), + mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanTensor(self, distribution): + def _dataset_fn(): + dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) + # Want to produce a fixed, known shape, so drop remainder when batching. + dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + return dataset + + def _expected_fn(num_batches): + # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2 + # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1 + # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches + # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1 + first = 2. * num_batches - 2. + return [first, first + 1., first + 2., first + 3.] + + self._test_metric( + distribution, _dataset_fn, metrics.mean_tensor, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCROC(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.5, 7./9, 0.8, 0.75][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCPR(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="PR", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[0.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 3., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [3.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecision(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 0.5, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecisionAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecall(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecallAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1./32, 0.208333, 0.15625][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRootMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.root_mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 0.176777, 0.456435, 0.395285][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSensitivityAtSpecificity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.sensitivity_at_specificity(labels, predictions, 0.8) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSpecificityAtSensitivity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.specificity_at_sensitivity(labels, predictions, 0.95) + + def _expected_fn(num_batches): + return [0., 1./3, 0.5, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 5c056a7c73def2f1fb4bbe0df4d3f82fdabda3df..aeeb9553e6044a0a928936597400e582e0329b95 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -56,6 +56,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): is_tpu=[True])) def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, is_tpu): + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + with distribution.scope(): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) @@ -84,8 +88,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) if is_tpu: with self.test_session() as sess: @@ -111,6 +115,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): is_tpu=[True])) def testOptimizerInsideModelFn(self, distribution, optimizer_fn, is_tpu): + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + created_variables = [] trainable_variables = [] @@ -186,7 +194,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # towers will re-execute UPDATE_OPS of previous towers. update_ops_in_cross_tower_mode=[True])) + combinations.combine( - distribution=[combinations.tpu_strategy_single_iteration], + distribution=[combinations.tpu_strategy], optimizer_fn=[ combinations.gradient_descent_optimizer_v1_fn, combinations.gradient_descent_optimizer_v2_fn @@ -198,6 +206,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): renorm, is_tpu, update_ops_in_cross_tower_mode): """Verifies that moving mean updates are reduced across towers.""" + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + with distribution.scope(): num_towers = len(distribution.worker_devices) model_fn, dataset_fn, batchnorm = batchnorm_example( @@ -242,7 +254,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean)) + moving_means = self.evaluate(batchnorm.moving_mean) # We make sure that the moving_mean is updated as if the sample mean is # calculated over all towers. @@ -279,12 +291,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): mode=["graph"], use_callable_loss=[True, False]) + combinations.combine(mode=["eager"], use_callable_loss=[True])) + combinations.combine( - distribution=[combinations.tpu_strategy_single_iteration], + distribution=[combinations.tpu_strategy], is_tpu=[True], mode=["graph"], use_callable_loss=[True, False]))) def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, use_callable_loss, is_tpu): + # TODO(priyag): Remove this once the step TPU Strategy is stable. + if is_tpu: + self.skipTest("TPU tests are WIP.") + with distribution.scope(): all_vars = [] @@ -329,7 +345,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): v = all_vars[0] self.assertTrue(all([v is vi for vi in all_vars[1:]])) - weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) + weight = numpy.squeeze(self.evaluate(v)) # Our model is: # predict = x * w # loss = (predict - y)^2 diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 6eadba976b7fb0e2902959755ff62fc2c8bc3660..d8668b398f227d6f74d3b6d4dbc316213d39536e 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -108,6 +108,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if tower_local is not None: kwargs["trainable"] = False + # Ignore user-specified caching device, not needed for mirrored variables. + kwargs.pop("caching_device", None) + # TODO(josh11b,apassos): It would be better if variable initialization # was never recorded on the tape instead of having to do this manually # here. @@ -118,7 +121,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if i > 0: # Give replicas meaningful distinct names: var0name = index[devices[0]].name.split(":")[0] - kwargs["name"] = "%s/replica_%d" % (var0name, i) + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: if context.executing_eagerly(): kwargs["initial_value"] = array_ops.identity( @@ -258,8 +264,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): {t.device: t.merge_args for t in threads}) merge_kwargs = values.regroup( {t.device: t.merge_kwargs for t in threads}) - merge_result = threads[0].merge_fn( - self, *merge_args, **merge_kwargs) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. + mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn( + self, *merge_args, **merge_kwargs) for t in threads: t.merge_result = values.select_device(t.device, merge_result) finally: @@ -272,8 +285,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def map(self, map_over, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. index = {} - i = 0 - for m in map_over: + for i, m in enumerate(map_over): d = self._devices[i % len(self._devices)] with ops.device(d): l = index.get(d, []) @@ -309,14 +321,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): value_destination_pairs) def _update(self, var, fn, *args, **kwargs): - # TODO(josh11b): Also support TowerLocalVariables here? If so, args and - # kwargs don't need to be mirrored. - assert isinstance(var, values.MirroredVariable) # TODO(josh11b): In eager mode, use one thread per device. + assert isinstance(var, values.DistributedVariable) updates = {} for d, v in var._index.items(): # pylint: disable=protected-access name = "update_%d" % self._device_index.get(d) with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + # If args and kwargs are not mirrored, the value is returned as is. updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) @@ -333,6 +344,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): **values.select_device_mirrored(d, kwargs)) return values.regroup(updates, values.Mirrored) + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + if isinstance(tower_local_var, values.TowerLocalVariable): + return tower_local_var._get_cross_tower() # pylint: disable=protected-access + assert isinstance(tower_local_var, values.Mirrored) + return array_ops.identity(tower_local_var.get()) + def _fetch(self, val, destination, fn): """Return a copy of `val` or `fn(val)` on `destination`.""" if isinstance(val, values.TowerLocalVariable): @@ -428,6 +446,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self.merge_args = None self.merge_kwargs = None self.merge_result = None + self.captured_name_scope = None # We use a thread.Event for the main thread to signal when this # thread should start running (`should_run`), and another for # this thread to transfer control back to the main thread @@ -451,13 +470,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._variable_creator_stack = self.graph._variable_creator_stack[:] self._captured_var_scope = variable_scope.get_variable_scope() # Adding a "/" at end lets us re-enter this scope later. - self._captured_name_scope = self.graph.get_name_scope() - if self._captured_name_scope: - self._captured_name_scope += "/" + self._name_scope = self.graph.get_name_scope() + if self._name_scope: + self._name_scope += "/" if self.tower_id > 0: - if not self._captured_name_scope: - self._captured_name_scope = "" - self._captured_name_scope += "tower_%d/" % self.tower_id + if not self._name_scope: + self._name_scope = "" + self._name_scope += "tower_%d/" % self.tower_id def run(self): # pylint: disable=protected-access @@ -473,7 +492,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): _enter_graph(self.graph), \ MirroredTowerContext(self.distribution, self.tower_id), \ ops.device(self.device), \ - ops.name_scope(self._captured_name_scope), \ + ops.name_scope(self._name_scope), \ variable_scope.variable_scope( self._captured_var_scope, reuse=self.tower_id > 0), \ variable_scope.variable_creator_scope(self.variable_creator_fn): @@ -499,6 +518,10 @@ class MirroredTowerContext(distribute_lib.TowerContext): t.merge_fn = fn t.merge_args = args t.merge_kwargs = kwargs + t.captured_name_scope = t.graph.get_name_scope() + # Adding a "/" at end lets us re-enter this scope later. + if t.captured_name_scope: + t.captured_name_scope += "/" t.has_paused.set() t.should_run.wait() t.should_run.clear() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 3f9a02b249dde9a66056ed8952b664bbc3f74ead..cb150692de8bffcf590e68f91945442ce7d6dfe0 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -337,6 +337,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase): all_v_sum = {} all_v_mean = {} + components_sum = {} + components_mean = {} def model_fn(device_id): tower_context = distribute_lib.get_tower_context() @@ -350,21 +352,33 @@ class MirroredStrategyVariableCreationTest(test.TestCase): v_mean.assign(6.0 * device_id)] all_v_sum[device_id] = v_sum all_v_mean[device_id] = v_mean - return updates, v_sum, v_mean + c_sum = v_sum.get() + c_mean = v_mean.get() + components_sum[device_id] = c_sum + components_mean[device_id] = c_mean + self.assertIsNot(v_sum, c_sum) + self.assertIsNot(v_mean, c_mean) + return updates, v_sum, v_mean, c_sum, c_mean dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) with dist.scope(): # Create "sum" and "mean" versions of TowerLocalVariables. - ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower( - model_fn, dist.worker_device_index, run_concurrently=False) + ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = ( + dist.call_for_each_tower( + model_fn, dist.worker_device_index, run_concurrently=False)) # Should see the same wrapping instance in all towers. self.assertIs(all_v_sum[0], ret_v_sum) self.assertIs(all_v_mean[0], ret_v_mean) - for i in range(1, dist.num_towers): - self.assertIs(all_v_sum[0], all_v_sum[1]) - self.assertIs(all_v_mean[0], all_v_mean[1]) + self.assertIs(all_v_sum[0], all_v_sum[1]) + self.assertIs(all_v_mean[0], all_v_mean[1]) + + # Regroup should recover the same wrapper. + self.assertIs(ret_v_sum, regrouped_sum) + self.assertIs(ret_v_mean, regrouped_mean) + self.assertIsNot(components_sum[0], components_sum[1]) + self.assertIsNot(components_mean[0], components_mean[1]) # Apply updates self.evaluate(variables.global_variables_initializer()) @@ -385,14 +399,13 @@ class MirroredStrategyVariableCreationTest(test.TestCase): # Without get(device), should return the value you get by # applying the reduction across all towers (whether you use - # fetch(), get(), or nothing). - self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) - self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + # read_var(), get(), or nothing). + self.assertEqual(expected_sum, self.evaluate(dist.read_var(ret_v_sum))) + self.assertEqual(expected_mean, self.evaluate(dist.read_var(ret_v_mean))) self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) - if not context.executing_eagerly(): - self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) - self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. @@ -438,6 +451,74 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("foo/" + name + ":0", v0.name) self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + # variable_scope.variable() respects name scopes when creating + # variables. On the other hand variable_scope.get_variable() ignores name + # scopes when creating variables. We test both methods of creating variables + # to make sure that we have the same variable names in both cases. + def testNameScopeWithVariable(self): + def in_cross_tower(_): + c = variable_scope.variable(1.0, name="c") + return c + + def model_fn(): + b = variable_scope.variable(1.0, name="b") + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.variable(1.0, name="a") + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("main/a:0", a0.name) + self.assertEquals("main/a/replica_1:0", a1.name) + self.assertEquals("main/b:0", b0.name) + self.assertEquals("main/b/replica_1:0", b1.name) + self.assertEquals("main/foo/c:0", c0.name) + self.assertEquals("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self): + def in_cross_tower(_): + c = variable_scope.get_variable("c", [1]) + return c + + def model_fn(): + b = variable_scope.get_variable("b", [1]) + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.get_variable("a", [1]) + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("a:0", a0.name) + self.assertEquals("a/replica_1:0", a1.name) + self.assertEquals("b:0", b0.name) + self.assertEquals("b/replica_1:0", b1.name) + self.assertEquals("c:0", c0.name) + self.assertEquals("c/replica_1:0", c1.name) + def testDynamicRnnVariables(self): def model_fn(): inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) @@ -462,6 +543,43 @@ class MirroredStrategyVariableCreationTest(test.TestCase): _, v1 = dist.unwrap(v) self.assertStartsWith(v1.name, "tower_1/") + @test_util.run_in_graph_and_eager_modes(config=config) + def testTowerLocalVariableUpdate(self): + with context.graph_mode(): + + def model_fn(): + tower_context = distribute_lib.get_tower_context() + with tower_context.tower_local_var_scope("sum"): + v_sum = variable_scope.variable(1.0) + self.assertTrue(isinstance(v_sum, values.TowerLocalVariable)) + return v_sum + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:GPU:1"]) + + def update(var, value): + return var.assign(value) + + with dist.scope(): + ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False) + update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0)) + + # Initialize variables. + self.evaluate(variables.global_variables_initializer()) + # Assert that the aggregated value of the tower local vars is the sum of + # the individual values before running the update ops. + self.assertEquals(1.0, self.evaluate( + ret_v_sum.get(dist._devices[0]).read_value())) + self.assertEquals(2.0, self.evaluate(ret_v_sum)) + + # Apply updates. + self.evaluate(update_ops) + # Assert that the aggregated value of the tower local vars is the sum of + # the individual values after running the update ops. + self.assertEquals(5.0, self.evaluate( + ret_v_sum.get(dist._devices[0]).read_value())) + self.assertEquals(10.0, self.evaluate(ret_v_sum)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py index 4fdb9bf69b4f6ad76b79fd298f5303f24a1bd455..2892ce439494320a115b8eae0025a132841c4a8f 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -52,11 +52,11 @@ class MonitorTest(test.TestCase, parameterized.TestCase): self.assertEqual(1, len(layer.trainable_variables)) mirrored_weight_variable = layer.trainable_variables[0] - start_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + start_error = self.evaluate(mirrored_weight_variable) start_error = abs(numpy.array(start_error) - 1) monitor.run_steps(9) - end_error = self.evaluate(distribution.fetch(mirrored_weight_variable)) + end_error = self.evaluate(mirrored_weight_variable) end_error = abs(numpy.array(end_error) - 1) self.assertGreaterEqual(start_error, end_error) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py index a552b370ebf359464afcaf3211119e73434e0dfb..0f21a427320510635279f80c11711e81715ec37c 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -121,7 +121,7 @@ class MultiWorkerMirroredStrategy(MirroredStrategy): worker: [device_util.canonicalize(worker, '/device:CPU:0')] for worker in self._workers } - self._devices = nest.flatten(self._worker_device_map.values()) + self._devices = nest.flatten(self._worker_device_map) super(MultiWorkerMirroredStrategy, self).__init__( devices=self._devices, prefetch_on_device=prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 09b6d4a515ab46879520f304cd5ef60469512380..7f4bab9d93814eb70a2a1586fc291a16b2766b90 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -102,6 +102,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.device(self._device), distribute_lib.UpdateContext(self._device): return fn(*args, **kwargs) + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + return array_ops.identity(tower_local_var) + def _fetch(self, val, destination, fn): """Return a copy of `val` or `fn(val)` on `destination`.""" with ops.device(self._device): diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index abd3a65ac4e19ece6b69b9834f4218fde55b60c2..a2d736e42271ab1627240949b99088ed3f0746f6 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -59,8 +59,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 75c5ec9659d193e77d219ba79977615d58841d64..2ee94d8f70868c07ca217dd4d433585458efa8d8 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -50,8 +50,8 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(10): run_step() - weights.append(self.evaluate(distribution.fetch(layer.kernel))) - biases.append(self.evaluate(distribution.fetch(layer.bias))) + weights.append(self.evaluate(layer.kernel)) + biases.append(self.evaluate(layer.bias)) error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 2b4ad9f146bc1d6a987fbeecbb05122946137154..d2fe8b3b1efabf7b35c070a82d01595f3fa51bf9 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -106,13 +106,13 @@ class DistributionTestBase(test.TestCase): before_list = [] after_list = [] for g, v in g_v: - fetched = d.fetch(v) + fetched = d.read_var(v) before_list.append(fetched) # control_dependencies irrelevant but harmless in eager execution with ops.control_dependencies([fetched]): g = d.reduce("sum", g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): - after_list.append(d.fetch(v)) + after_list.append(d.read_var(v)) return before_list, after_list for i in range(10): @@ -159,12 +159,12 @@ class DistributionTestBase(test.TestCase): before_list = [] after_list = [] for g, v in g_v: - fetched = d.fetch(v) + fetched = d.read_var(v) before_list.append(fetched) with ops.control_dependencies([fetched]): g = d.reduce("sum", g, destinations=v) with ops.control_dependencies(d.unwrap(d.update(v, update, g))): - after_list.append(d.fetch(v)) + after_list.append(d.read_var(v)) return before_list, after_list before_out, after_out = step() @@ -184,7 +184,7 @@ class DistributionTestBase(test.TestCase): with d.scope(): map_in = [constant_op.constant(i) for i in range(10)] map_out = d.map(map_in, lambda x, y: x * y, 2) - observed = d.fetch(d.reduce("sum", map_out)) + observed = d.reduce("sum", map_out) expected = 90 # 2 * (0 + 1 + ... + 9) self.assertEqual(expected, observed.numpy()) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 75441786a615fc0d87b4c4b0b45b9384d678c1d3..b177e09adbc89684ff885d2903a30cc3696a2140 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -21,11 +21,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools - from tensorflow.contrib import tpu from tensorflow.contrib.distribute.python import one_device_strategy -from tensorflow.contrib.distribute.python import values from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops @@ -36,86 +33,83 @@ from tensorflow.python.util import nest class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" - def __init__(self, - num_cores_per_host=2, - iterations_per_step=2): + def __init__(self, num_cores_per_host=2): # TODO(isaprykin): Generalize the defaults. They are currently tailored for # the unit test. super(TPUStrategy, self).__init__('/cpu:0') # TODO(isaprykin): Auto-detect number of cores and hosts. self._num_cores_per_host = num_cores_per_host - # TODO(isaprykin): This might have to be per-call. - self._iterations_per_step = iterations_per_step + # TODO(priyag): This should not be hardcoded here. + self._host = '/task:0/device:CPU:0' def distribute_dataset(self, dataset_fn): - return values.PerIterationDataset( - self._call_dataset_fn(dataset_fn), self._iterations_per_step, - self._num_cores_per_host) - - def _call_for_each_tower(self, fn, *args, **kwargs): - kwargs.pop('run_concurrently', None) - - inputs = {'args': args, 'kwargs': kwargs} - flat_inputs = nest.flatten(inputs) - - feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs] + # TODO(priyag): Perhaps distribute across cores here. + return self._call_dataset_fn(dataset_fn) - feeds = lambda: itertools.compress(flat_inputs, feed_mask) - shapes = [f.get_shape() for f in feeds()] + # TODO(priyag): Deal with OutOfRange errors. + def run_steps_on_dataset(self, fn, iterator, iterations): + # Enqueue ops + shapes = nest.flatten(iterator.output_shapes) if any([not s.is_fully_defined() for s in shapes]): raise ValueError( 'TPU currently requires fully defined shapes. Either use ' 'set_shape() on the input tensors or use ' 'dataset.apply(map_and_batch(..., drop_remainder=True)).') - types = [f.get_dtype() for f in feeds()] - - def infeed_input(i): - """Get input, split it and then enqueue.""" - iteration_inputs = [f.get(i) for f in feeds()] - infeed_inputs = [[inputs_per_core[core_id] - for inputs_per_core in iteration_inputs] - for core_id in range(self._num_cores_per_host)] - - infeed_ops = [] - for core_id, infeed_input in enumerate(infeed_inputs): - infeed_ops.append( + types = nest.flatten(iterator.output_types) + + def enqueue_ops_fn(): + """Enqueue ops for one iteration.""" + control_deps = [] + sharded_inputs = [] + with ops.device(self._host): + for _ in range(self._num_cores_per_host): + # Use control dependencies to ensure a deterministic ordering. + with ops.control_dependencies(control_deps): + inputs = nest.flatten(iterator.get_next()) + control_deps.extend(inputs) + sharded_inputs.append(inputs) + + enqueue_ops = [] + for core_id, shard_input in enumerate(sharded_inputs): + enqueue_ops.append( tpu_ops.infeed_enqueue_tuple( - inputs=infeed_input, shapes=shapes, device_ordinal=core_id)) + inputs=shard_input, shapes=shapes, device_ordinal=core_id)) + return enqueue_ops - with ops.control_dependencies(infeed_ops): + def enqueue_ops_loop_body(i): + with ops.control_dependencies(enqueue_ops_fn()): return i + 1 - with ops.device('/task:0/device:CPU:0'): + with ops.device(self._host): enqueue_ops = control_flow_ops.while_loop( - lambda i: i < self._iterations_per_step, - infeed_input, [constant_op.constant(0)], + lambda i: i < iterations, + enqueue_ops_loop_body, + [constant_op.constant(0)], parallel_iterations=1) - def dequeueing_fn(*args, **kwargs): - """Dequeue input arguments and supply them to `fn`.""" - del args, kwargs + # Dequeue ops + def dequeue_fn(): dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes) - dequeued = iter(dequeued) + return nest.pack_sequence_as(iterator.output_shapes, dequeued) - fn_inputs = [] - for inp, is_feed in zip(flat_inputs, feed_mask): - if is_feed: - fn_inputs.append(next(dequeued)) - else: - fn_inputs.append(inp) - - fn_inputs = nest.pack_sequence_as(inputs, fn_inputs) - return fn(*fn_inputs['args'], **fn_inputs['kwargs']) + # Wrap `fn` for repeat. + run_fn = lambda: fn(dequeue_fn()) + # Repeat def iterate_on_tpu(): - return tpu.repeat(self._iterations_per_step, dequeueing_fn, []) + return tpu.repeat(iterations, run_fn, []) - with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access - tpu_result = tpu.batch_parallel( - iterate_on_tpu, [], num_shards=self._num_cores_per_host) + # Re-write and distribute computation. + tpu_result = tpu.batch_parallel( + iterate_on_tpu, [], num_shards=self._num_cores_per_host) return control_flow_ops.group(tpu_result, enqueue_ops) + def _call_for_each_tower(self, fn, *args, **kwargs): + kwargs.pop('run_concurrently', None) + with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access + return fn(*args, **kwargs) + def _reduce(self, method_string, value, destinations): del destinations # TPU is graph mode only. Rely on implicit Send/Recv. if method_string == 'mean': diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 9572ade8e497fa13a7ca0746399d3e0237ee79fd..9a48928a9530c3b24b603acf8b0b7584f5294b9e 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -26,7 +26,6 @@ import weakref import six -from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import input_ops from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context @@ -43,7 +42,7 @@ from tensorflow.python.util import nest # pylint: disable=line-too-long -# TODO(josh11b): Should device values be strings or DeviceSpec objects +# TODO(josh11b): Should device values be strings or DeviceSpec objects? # Not sure DeviceSpec objects are usable as a dict key. class DistributedValues(object): """Holds a map from device to values. Either PerDevice or Mirrored.""" @@ -163,9 +162,16 @@ class PerDevice(DistributedValues): pass -class Mirrored(DistributedValues): +# Note that unlike PerDevice, Mirrored values inherit from +# DistributedDelegate and so can be used directly in cross-tower mode. +class Mirrored(DistributedDelegate): """Holds a map from device to values which are kept in sync.""" - pass + + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return self._index[device] + return list(self._index.values())[0] def _assign_on_device(device, variable, tensor): @@ -186,6 +192,10 @@ class DistributedVariable(DistributedDelegate): # Child class must set self._primary_var before calling # super(...).__init__(index). self._common_name = self._primary_var.name.split(":")[0] + # Use a weakref to make it easy to map from the contained values + # to the container without introducing a reference cycle. + for v in six.itervalues(index): + v._distributed_container = weakref.ref(self) # pylint: disable=protected-access super(DistributedVariable, self).__init__(index) @property @@ -238,17 +248,6 @@ class DistributedVariable(DistributedDelegate): pass -# Register a conversion function which reads the value of the variable, -# allowing instances of the class to be used as tensors. -def _tensor_conversion(var, dtype=None, name=None, as_ref=False): - # Try to avoid assignments to and other mutations of MirroredVariable - # state except through a DistributionStrategy.update() call. - assert not as_ref - return ops.internal_convert_to_tensor( - var.get(), dtype=dtype, name=name, as_ref=as_ref) - - -ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion) ops.register_dense_tensor_like_type(DistributedVariable) @@ -292,10 +291,6 @@ class MirroredVariable(DistributedVariable, Mirrored, """Holds a map from device to variables whose values are kept in sync.""" def __init__(self, index, primary_var): - # Use a weakref to make it easy to map from the contained values - # to the container without introducing a reference cycle. - for v in six.itervalues(index): - v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access self._primary_var = primary_var super(MirroredVariable, self).__init__(index) @@ -342,6 +337,20 @@ class MirroredVariable(DistributedVariable, Mirrored, return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} +# Register a conversion function which reads the value of the variable, +# allowing instances of the class to be used as tensors. +def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False): + # Try to avoid assignments to and other mutations of MirroredVariable + # state except through a DistributionStrategy.update() call. + assert not as_ref + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(MirroredVariable, + _tensor_conversion_mirrored) + + class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): """Class for defining how to restore a TowerLocalVariable.""" @@ -350,7 +359,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): # We use a callable so that we don't have to evaluate this expression # in the case where we are trying to restore instead of save. def tensor(): - return distribute_lib.get_distribution_strategy().fetch( + return distribute_lib.get_distribution_strategy().read_var( tower_local_variable) spec = saver.BaseSaverBuilder.SaveSpec( tensor=tensor, @@ -431,6 +440,17 @@ class TowerLocalVariable(DistributedVariable, PerDevice, return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory} +# Register a conversion function for TowerLocalVariable which allows as_ref to +# be true. +def _tensor_conversion_tower_local(var, dtype=None, name=None, as_ref=False): + return ops.internal_convert_to_tensor( + var.get(), dtype=dtype, name=name, as_ref=as_ref) + + +ops.register_tensor_conversion_function(TowerLocalVariable, + _tensor_conversion_tower_local) + + def _devices_match(d1, d2): return device_util.canonicalize(d1) == device_util.canonicalize(d2) @@ -478,40 +498,40 @@ def regroup(per_device, wrap_class=PerDevice): same_id = False break # Consider three cases where same_id is true: - # * If v0 is a MirroredVariable (and same_id means it is the same - # across all devices), we want to return it. We check - # MirroredVariable specifically since it can look like it - # has a _mirrored_container member since its members do. - # * If v0 is a member of a mirrored variable, in which case - # hasattr(v0, "_mirrored_container") is true, we want to - # return the MirroredVariable that contains it using the - # _mirrored_container logic below. This case can trigger + # * If v0 is a DistributedVariable (a MirroredVariable or + # TowerLocalVariable, and same_id means it is the same across all + # devices), we want to return it. We check DistributedVariable + # specifically since it can look like it has a + # _distributed_container member since its members do. + # * If v0 is a member of a distributed variable, in which case + # hasattr(v0, "_distributed_container") is true, we want to + # return the DistributedVariable that contains it using the + # _distributed_container logic below. This case can trigger # same_id when there is only one device. # * In any other situation, same_id means we return v0. - if same_id and (isinstance(v0, MirroredVariable) or - not hasattr(v0, "_mirrored_container")): + if same_id and (isinstance(v0, DistributedVariable) or + not hasattr(v0, "_distributed_container")): return v0 # Detect the case where each device has a parallel component of the - # same MirroredVariable. In this case we want to return the - # containing MirroredVariable, after a bunch of sanity checking. - # In particular, each component should have the same container, - # and the devices of the variables should match the keys of the - # per-device dictionary. - # TODO(josh11b): Do we need similar logic for TowerLocalVariables? - if hasattr(v0, "_mirrored_container"): + # same MirroredVariable (or TowerLocalVariable). In this case we + # want to return the containing MirroredVariable, after a bunch of + # sanity checking. In particular, each component should have the + # same container, and the devices of the variables should match the + # keys of the per-device dictionary. + if hasattr(v0, "_distributed_container"): # pylint: disable=protected-access assert not isinstance(v0, MirroredVariable), ( "ids = %s, items = %s" % ([id(v[1]) for v in items], items)) assert _devices_match(v0.device, items[0][0]), ( "v0.device = %s, items = %s" % (v0.device, items)) - mirrored_container = v0._mirrored_container() - assert mirrored_container is not None + distributed_container = v0._distributed_container() + assert distributed_container is not None for d, v in items[1:]: assert _devices_match(v.device, d), ( "v.device = %s, d = %s, items = %s" % (v.device, d, items)) - assert mirrored_container is v._mirrored_container() - return mirrored_container + assert distributed_container is v._distributed_container() + return distributed_container # pylint: enable=protected-access return wrap_class(per_device) @@ -593,8 +613,7 @@ class PerDeviceDataset(object): # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. # Possibly not an issue when we start using PartitionedDataset. - self._dataset = dataset.apply( - batching.batch_and_drop_remainder(len(devices))) + self._dataset = dataset.batch(len(devices), drop_remainder=True) def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 1c95758d96aba47e9581dde6411763e98b99a968..b0bd92c7b054b52b071e5d7601bdc48117464822 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -966,6 +966,18 @@ class TowerLocalVariableTest(test.TestCase): save_path = self._save_normal() self._restore_tower_local_sum(save_path) + def testTensorConversion(self): + with context.graph_mode(): + _, tower_local = _make_tower_local("sum") + converted = ops.internal_convert_to_tensor(tower_local, as_ref=False) + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, tower_local.dtype) + + converted = ops.internal_convert_to_tensor(tower_local, as_ref=True) + # Resources variable are converted to tensors as well when as_ref is True. + self.assertIsInstance(converted, ops.Tensor) + self.assertEqual(converted.dtype, tower_local.dtype) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 23d9dbcd91a25e7cbb5d6cfea5d63ba8412f4255..ad00d1734dd14ed846522a33d888a5387cb25cc6 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -16,6 +16,13 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "bijectors_py", srcs = glob(["python/ops/bijectors/*.py"]), + deprecation = ("TensorFlow Distributions has migrated to " + + "TensorFlow Probability " + + "(https://github.com/tensorflow/probability). " + + "Deprecated copies remaining in tf.contrib.distributions " + + "are unmaintained, unsupported, and will be removed by " + + "late 2018. You should update all usage of " + + "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/linalg:linalg_py", @@ -42,6 +49,13 @@ py_library( py_library( name = "distributions_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + deprecation = ("TensorFlow Distributions has migrated to " + + "TensorFlow Probability " + + "(https://github.com/tensorflow/probability). " + + "Deprecated copies remaining in tf.contrib.distributions " + + "are unmaintained, unsupported, and will be removed by " + + "late 2018. You should update all usage of " + + "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ ":bijectors_py", @@ -940,6 +954,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "fill_triangular_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/fill_triangular_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "gumbel_test", size = "small", @@ -1118,6 +1151,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "scale_tril_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/scale_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sigmoid_test", size = "small", @@ -1235,6 +1287,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "transform_diagonal_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/transform_diagonal_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "weibull_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 802538ba97578ce6cfe7e3555963ecd2fd014a66..5cec93c4df2e970f203253be6342bb292f296eb0 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """Classes representing statistical distributions and ops for working with them. - -See the @{$python/contrib.distributions} guide. """ from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index e281e81bdf0698c1f7b2f60fb27783dd1351773f..d1ce273499c8a646c0757844c91a785fa8d56ce4 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -61,6 +61,28 @@ class CholeskyOuterProductBijectorTest(test.TestCase): atol=0., rtol=1e-7) + def testNoBatchStaticJacobian(self): + x = np.eye(2) + bijector = bijectors.CholeskyOuterProduct() + + # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4. + self.assertAllClose( + np.log(4), + self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=2))) + + def testNoBatchDynamicJacobian(self): + x = np.eye(2) + bijector = bijectors.CholeskyOuterProduct() + x_pl = array_ops.placeholder(dtypes.float32) + + with self.test_session(): + log_det_jacobian = bijector.forward_log_det_jacobian(x_pl, event_ndims=2) + + # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4. + self.assertAllClose( + np.log(4), + log_det_jacobian.eval({x_pl: x})) + def testNoBatchStatic(self): x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py new file mode 100644 index 0000000000000000000000000000000000000000..caeaf2a0c6e4fff28c0edd82cb09ca0bcee85fc3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class FillTriangularBijectorTest(test.TestCase): + """Tests the correctness of the FillTriangular bijector.""" + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(np.array([1., 2., 3.])) + y = np.float32(np.array([[3., 0.], + [2., 1.]])) + + b = bijectors.FillTriangular() + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + self.assertAllClose(fldj, 0.) + + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(ildj, 0.) + + @test_util.run_in_graph_and_eager_modes() + def testShape(self): + x_shape = tensor_shape.TensorShape([5, 4, 6]) + y_shape = tensor_shape.TensorShape([5, 4, 3, 3]) + + b = bijectors.FillTriangular(validate_args=True) + + x = array_ops.ones(shape=x_shape, dtype=dtypes.float32) + y_ = b.forward(x) + self.assertAllEqual(y_.shape.as_list(), y_shape.as_list()) + x_ = b.inverse(y_) + self.assertAllEqual(x_.shape.as_list(), x_shape.as_list()) + + y_shape_ = b.forward_event_shape(x_shape) + self.assertAllEqual(y_shape_.as_list(), y_shape.as_list()) + x_shape_ = b.inverse_event_shape(y_shape) + self.assertAllEqual(x_shape_.as_list(), x_shape.as_list()) + + y_shape_tensor = self.evaluate( + b.forward_event_shape_tensor(x_shape.as_list())) + self.assertAllEqual(y_shape_tensor, y_shape.as_list()) + x_shape_tensor = self.evaluate( + b.inverse_event_shape_tensor(y_shape.as_list())) + self.assertAllEqual(x_shape_tensor, x_shape.as_list()) + + @test_util.run_in_graph_and_eager_modes() + def testShapeError(self): + + b = bijectors.FillTriangular(validate_args=True) + + x_shape_bad = tensor_shape.TensorShape([5, 4, 7]) + with self.assertRaisesRegexp(ValueError, "is not a triangular number"): + b.forward_event_shape(x_shape_bad) + with self.assertRaisesOpError("is not a triangular number"): + self.evaluate(b.forward_event_shape_tensor(x_shape_bad.as_list())) + + y_shape_bad = tensor_shape.TensorShape([5, 4, 3, 2]) + with self.assertRaisesRegexp(ValueError, "Matrix must be square"): + b.inverse_event_shape(y_shape_bad) + with self.assertRaisesOpError("Matrix must be square"): + self.evaluate(b.inverse_event_shape_tensor(y_shape_bad.as_list())) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py new file mode 100644 index 0000000000000000000000000000000000000000..566a7b3dff9b5d97a1cb143e0b32fc15984c3a02 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py @@ -0,0 +1,69 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class ScaleTriLBijectorTest(test.TestCase): + """Tests the correctness of the ScaleTriL bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testComputesCorrectValues(self): + shift = 1.61803398875 + x = np.float32(np.array([-1, .5, 2])) + y = np.float32(np.array([[np.exp(2) + shift, 0.], + [.5, np.exp(-1) + shift]])) + + b = bijectors.ScaleTriL(diag_bijector=bijectors.Exp(), + diag_shift=shift) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + @test_util.run_in_graph_and_eager_modes() + def testInvertible(self): + + # Generate random inputs from an unconstrained space, with + # event size 6 to specify 3x3 triangular matrices. + batch_shape = [2, 1] + x = np.float32(np.random.randn(*(batch_shape + [6]))) + b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(), + diag_shift=3.14159) + y = self.evaluate(b.forward(x)) + self.assertAllEqual(y.shape, batch_shape + [3, 3]) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(fldj, -ildj) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 45760a29ee42835da69ef63803ccec7ce82a5a8f..795f1993ba5c31bf5a26333f31f1bc73125bff07 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -151,16 +151,24 @@ class SinhArcsinhBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval(), rtol=1e-4, atol=0.) self.assertAllClose(x, bijector.inverse(y).eval(), rtol=1e-4, atol=0.) - # Do the numpy calculation in float128 to avoid inf/nan. - y_float128 = np.float128(y) - self.assertAllClose( - np.log(np.cosh( - np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( - y_float128**2 + 1)) - - np.log(tailweight), - bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), - rtol=1e-4, - atol=0.) + # On IBM PPC systems, longdouble (np.float128) is same as double except that it can have more precision. + # Type double being of 8 bytes, can't hold square of max of float64 (which is also 8 bytes) and + # below test fails due to overflow error giving inf. So this check avoids that error by skipping square + # calculation and corresponding assert. + + if np.amax(y) <= np.sqrt(np.finfo(np.float128).max) and \ + np.fabs(np.amin(y)) <= np.sqrt(np.fabs(np.finfo(np.float128).min)): + + # Do the numpy calculation in float128 to avoid inf/nan. + y_float128 = np.float128(y) + self.assertAllClose( + np.log(np.cosh( + np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( + y_float128**2 + 1)) - + np.log(tailweight), + bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + rtol=1e-4, + atol=0.) self.assertAllClose( -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6428a68702274fae384ae3de6d03f7ca126e2346 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class TransformDiagonalBijectorTest(test.TestCase): + """Tests correctness of the TransformDiagonal bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(np.random.randn(3, 4, 4)) + + y = x.copy() + for i in range(x.shape[0]): + np.fill_diagonal(y[i, :, :], np.exp(np.diag(x[i, :, :]))) + + exp = bijectors.Exp() + b = bijectors.TransformDiagonal(diag_bijector=exp) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=2)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllEqual( + fldj, + self.evaluate(exp.forward_log_det_jacobian( + np.array([np.diag(x_mat) for x_mat in x]), + event_ndims=1))) + self.assertAllEqual( + ildj, + self.evaluate(exp.inverse_log_det_jacobian( + np.array([np.diag(y_mat) for y_mat in y]), + event_ndims=1))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index 31d24aa9ea09007b8db40e4869371b1f62639ac7..bbbec2103aefd3f38a9b734bcd3f2e15fc8bb683 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -29,7 +29,9 @@ from tensorflow.contrib.distributions.python.ops import mvn_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.linalg import linear_operator_diag @@ -540,5 +542,51 @@ class PadDynamicTest(_PadTest, test.TestCase): return False +class TestMoveDimension(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_move_dimension_static_shape(self): + + x = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1]) + + @test_util.run_in_graph_and_eager_modes() + def test_move_dimension_dynamic_shape(self): + + x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + x = array_ops.placeholder_with_default(input=x_, shape=None) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + x_perm = distribution_util.move_dimension(x, -1, 2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD index 03e26b198ea02ad1bef8bcd2f6076078ecd7df0b..42ecea034d77430924bd6f597bf42ec3f64fec92 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD @@ -34,7 +34,10 @@ py_test( name = "correlation_matrix_volumes_test", size = "medium", srcs = ["correlation_matrix_volumes_test.py"], - tags = ["no_pip"], + tags = [ + "no_pip", + "optonly", + ], deps = [ ":correlation_matrix_volumes_py", # For statistical testing diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index 11ca90c4833d84b092f0b43a8f5404e3a11450cd..bb9b8043b2233b2109f51b5dde188d088fdb0d39 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Autoregressive(distribution_lib.Distribution): @@ -107,6 +108,14 @@ class Autoregressive(distribution_lib.Distribution): https://arxiv.org/abs/1606.05328 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution_fn, sample0=None, diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index 4714caad69ee4341d259f6677decdd5842931834..519077bc9ab1063a1135486cfae34656f3f68157 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -71,6 +72,14 @@ class BatchReshape(distribution_lib.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution, batch_shape, @@ -352,6 +361,14 @@ class BatchReshape(distribution_lib.Distribution): return runtime_assertions +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def calculate_reshape(original_shape, new_shape, validate=False, name=None): """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" batch_shape_static = tensor_util.constant_value_as_shape(new_shape) @@ -384,6 +401,14 @@ def calculate_reshape(original_shape, new_shape, validate=False, name=None): return expanded_new_shape, batch_shape_static, validations +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def validate_init_args_statically(distribution, batch_shape): """Helper to __init__ which makes or raises assertions.""" if batch_shape.shape.ndims is not None: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 4965381ef33e14cef0e0339341d50c943d412d8f..e141f8b5c6423bd6cce4d09da6f49d55b3e25a24 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -24,6 +24,7 @@ @@CholeskyOuterProduct @@ConditionalBijector @@Exp +@@FillTriangular @@Gumbel @@Identity @@Inline @@ -36,12 +37,14 @@ @@PowerTransform @@RealNVP @@Reshape +@@ScaleTriL @@Sigmoid @@SinhArcsinh @@SoftmaxCentered @@Softplus @@Softsign @@Square +@@TransformDiagonal @@Weibull @@masked_autoregressive_default_template @@ -64,6 +67,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * +from tensorflow.contrib.distributions.python.ops.bijectors.fill_triangular import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * @@ -75,12 +79,14 @@ from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * +from tensorflow.contrib.distributions.python.ops.bijectors.scale_tril import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * from tensorflow.contrib.distributions.python.ops.bijectors.softsign import * from tensorflow.contrib.distributions.python.ops.bijectors.square import * +from tensorflow.contrib.distributions.python.ops.bijectors.transform_diagonal import * from tensorflow.python.ops.distributions.bijector import * from tensorflow.python.ops.distributions.identity_bijector import Identity diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index c9e31d7712f09f6c4b4cc6ae51a34c42a19c291d..4d6a46e7358933fdf512f49eae2673f35953c90a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -23,6 +23,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "AbsoluteValue", @@ -70,6 +71,14 @@ class AbsoluteValue(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="absolute_value"): """Instantiates the `AbsoluteValue` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index b4c2939eb914d50475ba6b1c1e979a804090f641..25f29452c3949600b8a4153a8585dd7269bd3b2b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -36,6 +37,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _as_tensor(x, name): """Convenience to convert to `Tensor` or leave as `None`.""" return None if x is None else ops.convert_to_tensor(x, name=name) @@ -97,6 +106,14 @@ class Affine(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale_identity_multiplier=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index 59f9742d576a7804f401d3a47ba31ae61d6c6e54..91301f15ad87e133777371b346864ecf7b964f27 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.util import deprecation __all__ = [ @@ -88,6 +89,14 @@ class AffineLinearOperator(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py index cd792e2c8cf48602daf9fb5eb56b8c34bac050c7..460d906231bd30f8cec4fe21d42afe7b2a05805e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -52,6 +53,14 @@ class AffineScalar(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py index 224cec8a63dba53a528490117efac890312fe8d5..f19f147dd645b4f805f1905899b44293284d4225 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -34,6 +35,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _undo_batch_normalization(x, mean, variance, @@ -128,6 +137,14 @@ class BatchNormalization(bijector.Bijector): Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, batchnorm_layer=None, training=True, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 16f959560ce0f171035b3ef0bd80b16dae1cc654..910774ea5bb4106a948567144c46c6db23a2c6e0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -31,10 +32,26 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _use_static_shape(input_tensor, ndims): return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _compute_min_event_ndims(bijector_list, compute_forward=True): """Computes the min_event_ndims associated with the give list of bijectors. @@ -142,6 +159,14 @@ class Chain(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, bijectors=None, validate_args=False, name=None): """Instantiates `Chain` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 268c8d03426d435dc38412ac1bd05c674bd05d2b..3e1e4fc82971b71792d193ea8518dd402e4a4d9d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -69,6 +70,14 @@ class CholeskyOuterProduct(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="cholesky_outer_product"): """Instantiates the `CholeskyOuterProduct` bijector. @@ -173,7 +182,20 @@ class CholeskyOuterProduct(bijector.Bijector): axis=-1) fldj = p_float * np.log(2.) + sum_weighted_log_diag - return fldj + # We finally need to undo adding an extra column in non-scalar cases + # where there is a single matrix as input. + if x.get_shape().ndims is not None: + if x.get_shape().ndims == 2: + fldj = array_ops.squeeze(fldj, axis=-1) + return fldj + + shape = array_ops.shape(fldj) + maybe_squeeze_shape = array_ops.concat([ + shape[:-1], + distribution_util.pick_vector( + math_ops.equal(array_ops.rank(x), 2), + np.array([], dtype=np.int32), shape[-1:])], 0) + return array_ops.reshape(fldj, maybe_squeeze_shape) def _make_columnar(self, x): """Ensures non-scalar input has at least one column. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index 9fc1bbf052b419d07a9db149b990c2b80190d72b..07627e1e45eae6b63d830b2adf036bdc3b1d2895 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops.bijectors import power_transform +from tensorflow.python.util import deprecation __all__ = [ @@ -47,6 +48,14 @@ class Exp(power_transform.PowerTransform): over the event space. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="exp"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py new file mode 100644 index 0000000000000000000000000000000000000000..31a9ca27e519bc312813668bf621a875838f12a0 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py @@ -0,0 +1,165 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as dist_util +from tensorflow.python.util import deprecation + + +__all__ = [ + "FillTriangular", +] + + +class FillTriangular(bijector.Bijector): + """Transforms vectors to triangular. + + Triangular matrix elements are filled in a clockwise spiral. + + Given input with shape `batch_shape + [d]`, produces output with + shape `batch_shape + [n, n]`, where + `n = (-1 + sqrt(1 + 8 * d))/2`. + This follows by solving the quadratic equation + `d = 1 + 2 + ... + n = n * (n + 1)/2`. + + #### Example + + ```python + b = tfb.FillTriangular(upper=False) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[4, 0, 0], + # [6, 5, 0], + # [3, 2, 1]] + + b = tfb.FillTriangular(upper=True) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[1, 2, 3], + # [0, 5, 6], + # [0, 0, 4]] + + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + upper=False, + validate_args=False, + name="fill_triangular"): + """Instantiates the `FillTriangular` bijector. + + Args: + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._upper = upper + super(FillTriangular, self).__init__( + forward_min_event_ndims=1, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return dist_util.fill_triangular(x, upper=self._upper) + + def _inverse(self, y): + return dist_util.fill_triangular_inverse(y, upper=self._upper) + + def _forward_log_det_jacobian(self, x): + return array_ops.zeros_like(x[..., 0]) + + def _inverse_log_det_jacobian(self, y): + return array_ops.zeros_like(y[..., 0, 0]) + + def _forward_event_shape(self, input_shape): + batch_shape, d = input_shape[:-1], input_shape[-1].value + if d is None: + n = None + else: + n = vector_size_to_square_matrix_size(d, self.validate_args) + return batch_shape.concatenate([n, n]) + + def _inverse_event_shape(self, output_shape): + batch_shape, n1, n2 = (output_shape[:-2], + output_shape[-2].value, + output_shape[-1].value) + if n1 is None or n2 is None: + m = None + elif n1 != n2: + raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2)) + else: + m = n1 * (n1 + 1) / 2 + return batch_shape.concatenate([m]) + + def _forward_event_shape_tensor(self, input_shape_tensor): + batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1] + n = vector_size_to_square_matrix_size(d, self.validate_args) + return array_ops.concat([batch_shape, [n, n]], axis=0) + + def _inverse_event_shape_tensor(self, output_shape_tensor): + batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1] + if self.validate_args: + is_square_matrix = check_ops.assert_equal( + n, output_shape_tensor[-2], message="Matrix must be square.") + with ops.control_dependencies([is_square_matrix]): + n = array_ops.identity(n) + d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype) + return array_ops.concat([batch_shape, [d]], axis=0) + + +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def vector_size_to_square_matrix_size(d, validate_args, name=None): + """Convert a vector size to a matrix size.""" + if isinstance(d, (float, int, np.generic, np.ndarray)): + n = (-1 + np.sqrt(1 + 8 * d)) / 2. + if float(int(n)) != n: + raise ValueError("Vector length is not a triangular number.") + return int(n) + else: + with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name: + n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2. + if validate_args: + with ops.control_dependencies([check_ops.assert_equal( + math_ops.to_float(math_ops.to_int32(n)), n, + message="Vector length is not a triangular number")]): + n = array_ops.identity(n) + return math_ops.cast(n, d.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index e656a258e56e71898ecb719dd2af876f158cf799..71e562a927a30a17d695b81c566f981db7553ad9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Gumbel", @@ -45,6 +46,14 @@ class Gumbel(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=0., scale=1., diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index 2bde956d1345129285acae4684256c5ac828b9a1..1504bd27204f728c0cb519159230e945128c4740 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -43,6 +44,14 @@ class Inline(bijector.Bijector): The above example is equivalent to the `Bijector` `Exp()`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, forward_fn=None, inverse_fn=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index 84a3289ba2160ed22a2bc7030dd612ba9ca6f6df..a648676d4b1956e5c27f67a71e6bd93d0d7fc97d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Invert", @@ -40,6 +41,14 @@ class Invert(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, bijector, validate_args=False, name=None): """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py index 97000c17262d3efdef10274711364c2bc2083bd4..33b75a04d34fdd01bc0d854d4e5b9c45a737b122 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Kumaraswamy", @@ -44,6 +45,14 @@ class Kumaraswamy(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration1=None, concentration0=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 83667b0e80cfcc1c4f0617cdc739221f24439665..b8f2a4b2c731bdaee78692c036fb9f2fba4e3760 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import template as template_ops from tensorflow.python.ops import variable_scope as variable_scope_lib from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -186,6 +187,14 @@ class MaskedAutoregressiveFlow(bijector.Bijector): Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift_and_log_scale_fn, is_constant_jacobian=False, @@ -296,6 +305,14 @@ MASK_INCLUSIVE = "inclusive" MASK_EXCLUSIVE = "exclusive" +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): """Generate the slices for building an autoregressive mask.""" # TODO(b/67594795): Better support of dynamic shape. @@ -313,6 +330,14 @@ def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): return slices +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _gen_mask(num_blocks, n_in, n_out, @@ -327,6 +352,14 @@ def _gen_mask(num_blocks, return mask +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def masked_dense(inputs, units, num_blocks=None, @@ -399,6 +432,14 @@ def masked_dense(inputs, return layer.apply(inputs) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def masked_autoregressive_default_template( hidden_layers, shift_only=False, @@ -515,6 +556,14 @@ def masked_autoregressive_default_template( "masked_autoregressive_default_template", _fn) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): """Clips input while leaving gradient unaltered.""" with ops.name_scope(name, "clip_by_value_preserve_grad", diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py index 71903f705232f0c5e5e0b3271550b4ef938c4f9d..49e6192f067edec4890dcfa107876a5104c14dd4 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -55,6 +56,14 @@ class MatrixInverseTriL(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="matrix_inverse_tril"): """Instantiates the `MatrixInverseTriL` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py index 3f03592f314cc13e8a9ea7e2ae18c5bb1f14e74f..fb393218b6b47764f45b5055bbf15cc17aba219e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -57,6 +58,14 @@ class Ordered(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="ordered"): super(Ordered, self).__init__( forward_min_event_ndims=1, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index 12a16a3f2ba3da53077307fd97d3f10d99b2c81f..f182a1adcbb6b11af2376cd271f903d50e50f1a0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -74,6 +75,14 @@ class Permute(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, permutation, validate_args=False, name=None): """Creates the `Permute` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index 71f123f2a998458edaa9c8da07ea2932f62625ca..16264fe728a334db347304500767ce5876f9db7e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -41,6 +42,14 @@ class PowerTransform(bijector.Bijector): This bijector is equivalent to the `Exp` bijector when `c=0`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, power=0., validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 66e8a5b9b356867424d1d47efaf848fc6903c371..773ae2446118051a61636bc21de6b81dfacda746 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import template as template_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -126,6 +127,14 @@ class RealNVP(bijector.Bijector): Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, num_masked, shift_and_log_scale_fn, @@ -228,6 +237,14 @@ class RealNVP(bijector.Bijector): return math_ops.reduce_sum(log_scale, axis=-1) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def real_nvp_default_template( hidden_layers, shift_only=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 5497c422e4d51e259435692dac722f801e8844ac..c8282229a30fabff0c4c267d0bdfcdbce4f5f3d9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -36,10 +37,26 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _static_ndims_from_shape(shape): return shape.shape.with_rank_at_least(1)[0].value +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _ndims_from_shape(shape): return array_ops.shape(shape)[0] @@ -86,6 +103,14 @@ class Reshape(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, event_shape_out, event_shape_in=(-1,), validate_args=False, name=None): """Creates a `Reshape` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..6fbe8665781211ca803feb8bf5a8c04fb0b969e8 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops.bijectors import affine_scalar +from tensorflow.contrib.distributions.python.ops.bijectors import chain +from tensorflow.contrib.distributions.python.ops.bijectors import fill_triangular +from tensorflow.contrib.distributions.python.ops.bijectors import softplus +from tensorflow.contrib.distributions.python.ops.bijectors import transform_diagonal +from tensorflow.python.util import deprecation + +__all__ = [ + "ScaleTriL", +] + + +class ScaleTriL(chain.Chain): + """Transforms unconstrained vectors to TriL matrices with positive diagonal. + + This is implemented as a simple `tfb.Chain` of `tfb.FillTriangular` + followed by `tfb.TransformDiagonal`, and provided mostly as a + convenience. The default setup is somewhat opinionated, using a + Softplus transformation followed by a small shift (`1e-5`) which + attempts to avoid numerical issues from zeros on the diagonal. + + #### Examples + + ```python + tfb = tf.contrib.distributions.bijectors + b = tfb.ScaleTriL( + diag_bijector=tfb.Exp(), + diag_shift=None) + b.forward(x=[0., 0., 0.]) + # Result: [[1., 0.], + # [0., 1.]] + b.inverse(y=[[1., 0], + [.5, 2]]) + # Result: [log(2), .5, log(1)] + + # Define a distribution over PSD matrices of shape `[3, 3]`, + # with `1 + 2 + 3 = 6` degrees of freedom. + dist = tfd.TransformedDistribution( + tfd.Normal(tf.zeros(6), tf.ones(6)), + tfb.Chain([tfb.CholeskyOuterProduct(), tfb.ScaleTriL()])) + + # Using an identity transformation, ScaleTriL is equivalent to + # tfb.FillTriangular. + b = tfb.ScaleTriL( + diag_bijector=tfb.Identity(), + diag_shift=None) + + # For greater control over initialization, one can manually encode + # pre- and post- shifts inside of `diag_bijector`. + b = tfb.ScaleTriL( + diag_bijector=tfb.Chain([ + tfb.AffineScalar(shift=1e-3), + tfb.Softplus(), + tfb.AffineScalar(shift=0.5413)]), # softplus_inverse(1.) + # = log(expm1(1.)) = 0.5413 + diag_shift=None) + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + diag_bijector=None, + diag_shift=1e-5, + validate_args=False, + name="scale_tril"): + """Instantiates the `ScaleTriL` bijector. + + Args: + diag_bijector: `Bijector` instance, used to transform the output diagonal + to be positive. + Default value: `None` (i.e., `tfb.Softplus()`). + diag_shift: Float value broadcastable and added to all diagonal entries + after applying the `diag_bijector`. Setting a positive + value forces the output diagonal entries to be positive, but + prevents inverting the transformation for matrices with + diagonal entries less than this value. + Default value: `1e-5` (i.e., no shift is applied). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + Default value: `False` (i.e., arguments are not validated). + name: Python `str` name given to ops managed by this object. + Default value: `scale_tril`. + """ + + if diag_bijector is None: + diag_bijector = softplus.Softplus(validate_args=validate_args) + + if diag_shift is not None: + diag_bijector = chain.Chain([affine_scalar.AffineScalar(shift=diag_shift), + diag_bijector]) + + super(ScaleTriL, self).__init__( + [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector), + fill_triangular.FillTriangular()], + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index 5df8c886315ff75cdc884e3b9b4665fb64bb109d..194b318fce31a13f84e7b664b58cebb24fc9a264 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -31,6 +32,14 @@ __all__ = [ class Sigmoid(bijector.Bijector): """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="sigmoid"): super(Sigmoid, self).__init__( forward_min_event_ndims=0, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index 2a32e8abcde940b0056b0faf2955ec1b3bd71803..241fba2cb7ec33b7b02c1ca79051f1b826d7d2aa 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -26,12 +26,21 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "SinhArcsinh", ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _sqrtx2p1(x): """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" return array_ops.where( @@ -88,6 +97,14 @@ class SinhArcsinh(bijector.Bijector): `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, skewness=None, tailweight=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index f52b91550edff7390d8094a4508d862674e85d59..20ee0d340833d5c5275e2ab52a89dcdf7198add1 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -60,6 +61,14 @@ class SoftmaxCentered(bijector.Bijector): makes the (forward) image non-open and the theorem does not directly apply. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="softmax_centered"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 96a938c803418ff818f9c531754b47ba1eb8667a..3df84ef8b04c2c8f6be91ecd1c972ad1484b4285 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -80,6 +81,14 @@ class Softplus(bijector.Bijector): "hinge_softness": ( "Nonzero floating point `Tensor`. Controls the softness of what " "would otherwise be a kink at the origin. Default is 1.0")}) + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, hinge_softness=None, validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py index b4a658c171b8313358754228aabbfa4bf93fd84d..f96a4bb01de59a21107b9e7c14f929e13e358ac9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py @@ -22,6 +22,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -51,6 +52,14 @@ class Softsign(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="softsign"): super(Softsign, self).__init__( forward_min_event_ndims=0, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py index 2ccfdc95970e387e708603e2614ad29fb6a18db3..294460a80f6209797831ea361e64efe677f71e59 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/square.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -49,6 +50,14 @@ class Square(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="square"): """Instantiates the `Square` bijector. @@ -81,4 +90,3 @@ class Square(bijector.Bijector): is_valid = check_ops.assert_non_negative( t, message="All elements must be non-negative.") return control_flow_ops.with_dependencies([is_valid], t) - diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7a3b026b8dcc31bed49c489d77b9c184f463cb --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py @@ -0,0 +1,111 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + +__all__ = [ + "TransformDiagonal", +] + + +class TransformDiagonal(bijector.Bijector): + """Applies a Bijector to the diagonal of a matrix. + + #### Example + + ```python + b = tfb.TransformDiagonal(diag_bijector=tfb.Exp()) + + b.forward([[1., 0.], + [0., 1.]]) + # ==> [[2.718, 0.], + [0., 2.718]] + ``` + + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + diag_bijector, + validate_args=False, + name="transform_diagonal"): + """Instantiates the `TransformDiagonal` bijector. + + Args: + diag_bijector: `Bijector` instance used to transform the diagonal. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._diag_bijector = diag_bijector + super(TransformDiagonal, self).__init__( + forward_min_event_ndims=2, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + diag = self._diag_bijector.forward(array_ops.matrix_diag_part(x)) + return array_ops.matrix_set_diag(x, diag) + + def _inverse(self, y): + diag = self._diag_bijector.inverse(array_ops.matrix_diag_part(y)) + return array_ops.matrix_set_diag(y, diag) + + def _forward_log_det_jacobian(self, x): + # We formulate the Jacobian with respect to the flattened matrices + # `vec(x)` and `vec(y)`. Suppose for notational convenience that + # the first `n` entries of `vec(x)` are the diagonal of `x`, and + # the remaining `n**2-n` entries are the off-diagonals in + # arbitrary order. Then the Jacobian is a block-diagonal matrix, + # with the Jacobian of the diagonal bijector in the first block, + # and the identity Jacobian for the remaining entries (since this + # bijector acts as the identity on non-diagonal entries): + # + # J_vec(x) (vec(y)) = + # ------------------------------- + # | J_diag(x) (diag(y)) 0 | n entries + # | | + # | 0 I | n**2-n entries + # ------------------------------- + # n n**2-n + # + # Since the log-det of the second (identity) block is zero, the + # overall log-det-jacobian is just the log-det of first block, + # from the diagonal bijector. + # + # Note that for elementwise operations (exp, softplus, etc) the + # first block of the Jacobian will itself be a diagonal matrix, + # but our implementation does not require this to be true. + return self._diag_bijector.forward_log_det_jacobian( + array_ops.matrix_diag_part(x), event_ndims=1) + + def _inverse_log_det_jacobian(self, y): + return self._diag_bijector.inverse_log_det_jacobian( + array_ops.matrix_diag_part(y), event_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index a22560fe80298b762795e7b0e7aea2db55823065..8903a70d98ae144731b12047e5074d0450b59378 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -47,6 +48,14 @@ class Weibull(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, scale=1., concentration=1., diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py index e4944beedcbca09b5eabd4daf1445ce4503b1c80..b349e5966dd750fdf96c0b211dce02658c9400b7 100644 --- a/tensorflow/contrib/distributions/python/ops/binomial.py +++ b/tensorflow/contrib/distributions/python/ops/binomial.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation _binomial_sample_note = """ @@ -42,6 +43,14 @@ to integer values. """ +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _bdtr(k, n, p): """The binomial cumulative distribution function. @@ -130,6 +139,14 @@ class Binomial(distribution.Distribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, total_count, logits=None, diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index 23b6a83c17d58652001543047febeebabba0c69f..cb5223b0557080e10bf24c3e1cb432f15fd5e7e3 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "Cauchy", @@ -92,6 +93,14 @@ class Cauchy(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index 686ae1ba74641e2b7b76667e512fa6453477a8da..e9a7b39070f3d76693ad54852ed0847a0980d2a6 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import gamma +from tensorflow.python.util import deprecation __all__ = [ @@ -63,6 +64,14 @@ class Chi2(gamma.Gamma): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, validate_args=False, @@ -114,6 +123,14 @@ class Chi2(gamma.Gamma): class Chi2WithAbsDf(Chi2): """Chi2 with parameter transform `df = floor(abs(df))`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index c44c76a133817640449ba126bb8ca25abadba5e6..ad853ee293f86565c1af601214522f53d936b70a 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "Deterministic", @@ -43,6 +44,14 @@ __all__ = [ class _BaseDeterministic(distribution.Distribution): """Base class for Deterministic distributions.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, @@ -203,6 +212,14 @@ class Deterministic(_BaseDeterministic): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, @@ -308,6 +325,14 @@ class VectorDeterministic(_BaseDeterministic): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 289e1d50e1146a641c0cc433ece3465aed73b1c2..6959b3e8775d2dd488b4ee3252d143ef376d58f9 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -21,12 +21,19 @@ from __future__ import print_function from tensorflow.contrib import linalg from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib + +# The following two lines are redundant, in a sense. The first enables +# good coding practice *within* this file (`util.prefer_static_value` +# rather than `prefer_static_value`). The second ensures that users +# also get the core utils when they import this file. +from tensorflow.python.ops.distributions import util from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import @@ -484,3 +491,75 @@ def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution, def static_value(x): """Returns the static value of a `Tensor` or `None`.""" return tensor_util.constant_value(ops.convert_to_tensor(x)) + + +def move_dimension(x, source_idx, dest_idx): + """Move a single tensor dimension within its shape. + + This is a special case of `tf.transpose()`, which applies + arbitrary permutations to tensor dimensions. + + Args: + x: Tensor of rank `ndims`. + source_idx: Integer index into `x.shape` (negative indexing is + supported). + dest_idx: Integer index into `x.shape` (negative indexing is + supported). + + Returns: + x_perm: Tensor of rank `ndims`, in which the dimension at original + index `source_idx` has been moved to new index `dest_idx`, with + all other dimensions retained in their original order. + + Example: + + ```python + x = tf.placeholder(shape=[200, 30, 4, 1, 6]) + x_perm = _move_dimension(x, 1, 1) # no-op + x_perm = _move_dimension(x, 0, 3) # result shape [30, 4, 1, 200, 6] + x_perm = _move_dimension(x, 0, -2) # equivalent to previous + x_perm = _move_dimension(x, 4, 2) # result shape [200, 30, 6, 4, 1] + ``` + """ + ndims = util.prefer_static_rank(x) + if isinstance(source_idx, int): + dtype = dtypes.int32 + else: + dtype = dtypes.as_dtype(source_idx.dtype) + + # Handle negative indexing. Since ndims might be dynamic, this makes + # source_idx and dest_idx also possibly dynamic. + if source_idx < 0: + source_idx = ndims + source_idx + if dest_idx < 0: + dest_idx = ndims + dest_idx + + # Construct the appropriate permutation of dimensions, depending + # whether the source is before or after the destination. + def move_left_permutation(): + return util.prefer_static_value( + array_ops.concat([ + math_ops.range(0, dest_idx, dtype=dtype), + [source_idx], + math_ops.range(dest_idx, source_idx, dtype=dtype), + math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0)) + + def move_right_permutation(): + return util.prefer_static_value( + array_ops.concat([ + math_ops.range(0, source_idx, dtype=dtype), + math_ops.range(source_idx+1, dest_idx+1, dtype=dtype), + [source_idx], + math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0)) + + def x_permuted(): + return array_ops.transpose( + x, perm=smart_cond.smart_cond(source_idx < dest_idx, + move_right_permutation, + move_left_permutation)) + + # One final conditional to handle the special case where source + # and destination indices are equal. + return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx), + lambda: x, + x_permuted) diff --git a/tensorflow/contrib/distributions/python/ops/estimator.py b/tensorflow/contrib/distributions/python/ops/estimator.py index 98edd337fe02ffbf53c6ecd9ebda9424231ea2fe..bdec6527d5378d6e86aa8e6279cc6ee672083e56 100644 --- a/tensorflow/contrib/distributions/python/ops/estimator.py +++ b/tensorflow/contrib/distributions/python/ops/estimator.py @@ -23,6 +23,7 @@ from tensorflow.contrib.learn.python.learn.estimators.head import _RegressionHea from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.util import deprecation __all__ = [ @@ -30,6 +31,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def estimator_head_distribution_regression(make_distribution_fn, label_dimension=1, logits_dimension=None, @@ -77,6 +86,14 @@ def estimator_head_distribution_regression(make_distribution_fn, class _DistributionRegressionHead(_RegressionHead): """Creates a _RegressionHead instance from an arbitrary `Distribution`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, make_distribution_fn, label_dimension, diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py index e1e42ee95d200df30c2c8a53a89cb5b7e9c4d17c..d62f024aa2a081f0ec231015af1f26a8851518e9 100644 --- a/tensorflow/contrib/distributions/python/ops/geometric.py +++ b/tensorflow/contrib/distributions/python/ops/geometric.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Geometric(distribution.Distribution): @@ -55,6 +56,14 @@ class Geometric(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, logits=None, probs=None, diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index 9d94fd11c62ce6ecd3d7daee35447bece2b4b2fb..acdea4d61d3ada7e9f4f0aa7bc58c5643db2802b 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation class _Gumbel(distribution.Distribution): @@ -96,6 +97,14 @@ class _Gumbel(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index 9c96254d1c0a593b955231132330931ff5f4ad07..b02c4031069191592b8acc1a90313450f98af6d7 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import special_math +from tensorflow.python.util import deprecation __all__ = [ @@ -85,6 +86,14 @@ class HalfNormal(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, scale, validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index cd6eaa8407477b4ed92f169bc0d2d80644d7c956..0672702b96c1eb81c176774554df3f5922a0319e 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.util import deprecation class Independent(distribution_lib.Distribution): @@ -94,6 +95,14 @@ class Independent(distribution_lib.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, distribution, reinterpreted_batch_ndims=None, validate_args=False, name=None): @@ -258,6 +267,14 @@ class Independent(distribution_lib.Distribution): @kullback_leibler.RegisterKL(Independent, Independent) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_independent(a, b, name="kl_independent"): """Batched KL divergence `KL(a || b)` for Independent distributions. diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 208057b34db2881b5c9c2adb102d02a87a333007..70d050d7a647b38928ddb1c788db0e6957ac0f03 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -95,6 +96,14 @@ class InverseGamma(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration, rate, @@ -274,6 +283,14 @@ class InverseGamma(distribution.Distribution): class InverseGammaWithSoftplusConcentrationRate(InverseGamma): """`InverseGamma` with softplus of `concentration` and `rate`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration, rate, diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 0ff989fc952c6fb3f54dad9a943eb36a0494a3be..e3712dd84e36609d6bba4a5a39866046c0c8d1d8 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import uniform +from tensorflow.python.util import deprecation __all__ = [ "Kumaraswamy", @@ -40,6 +41,14 @@ _kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in `[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _harmonic_number(x): """Compute the harmonic number from its analytic continuation. @@ -123,6 +132,14 @@ class Kumaraswamy(transformed_distribution.TransformedDistribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration1=None, concentration0=None, diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 27aa863440574eb0cdb5c7ae326e877d472999ad..02e3bad51ee48188acf83cb09359861c9e6932c7 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation class Logistic(distribution.Distribution): @@ -91,6 +92,14 @@ class Logistic(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index bfb53a06c011cec60cf5b2132e4b1106128a1ece..3b7114ef067c0aaede23fff04c40d1dc6e830f1c 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Mixture(distribution.Distribution): @@ -66,6 +67,14 @@ class Mixture(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, cat, components, diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 112eefd3691815ead19d59bc3aef5909b27ed169..8ffee940d03c9a5204f2ac6f7acd9ea482adae1a 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class MixtureSameFamily(distribution.Distribution): @@ -95,6 +96,14 @@ class MixtureSameFamily(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, mixture_distribution, components_distribution, @@ -321,6 +330,14 @@ class MixtureSameFamily(distribution.Distribution): return x +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _outer_squared_difference(x, y): """Convenience function analogous to tf.squared_difference.""" z = x - y diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index d2beb2aff0481eb4ec3a3abbf44fad5efff8eedd..cd0c282ba6cebf784261a4e821f36ce4eed98fe0 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -22,6 +22,7 @@ from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops import nn +from tensorflow.python.util import deprecation __all__ = [ @@ -134,6 +135,14 @@ class MultivariateNormalDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -218,6 +227,14 @@ class MultivariateNormalDiag( class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag): """MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale_diag, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 5117379b047f5e510a8a1a5490ddf76ee93d9d74..d8401801f21afbe8fd042053c6a38a31a2539438 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -22,6 +22,7 @@ from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -141,6 +142,14 @@ class MultivariateNormalDiagPlusLowRank( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 57f47db50c496f1e3e80d8177560b1bab594eb56..dbc4c1b3dc956641f3e38ffafe3a3410bd3e2097 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.util import deprecation __all__ = [ @@ -112,6 +113,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, covariance_matrix=None, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 6a0383db02555274239ee0b1845f24a705270d84..efe5a6d0d99ca8fa9e0274049423bb3c4eef2d6f 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -27,6 +27,7 @@ from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = [ @@ -133,6 +134,14 @@ class MultivariateNormalLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -266,6 +275,14 @@ class MultivariateNormalLinearOperator( @kullback_leibler.RegisterKL(MultivariateNormalLinearOperator, MultivariateNormalLinearOperator) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_brute_force(a, b, name=None): """Batched KL divergence `KL(a || b)` for multivariate Normals. diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index c809ef3c1cb5b8b9cd892b98d81e57710807d0aa..d9110947ecdbba1a63669573f46db17b02e512ab 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -22,6 +22,7 @@ from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -134,6 +135,14 @@ class MultivariateNormalTriL( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_tril=None, diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index 2bd11e24b315e044624344580108a232d1b6da89..6acfc5746a0cc20e916de81b71f90e08d8d91ad5 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class NegativeBinomial(distribution.Distribution): @@ -51,6 +52,14 @@ class NegativeBinomial(distribution.Distribution): * `n!` is the factorial of `n`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, total_count, logits=None, diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index 3e44c10fab726ad1299cc852a5e1391fecb8b390..214c6dca4a7f2b4cd6242e1b7ca78be9eeffb851 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class OneHotCategorical(distribution.Distribution): @@ -83,6 +84,14 @@ class OneHotCategorical(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, logits=None, @@ -226,13 +235,21 @@ class OneHotCategorical(distribution.Distribution): return x return control_flow_ops.with_dependencies([ check_ops.assert_non_positive(x), - distribution_util.assert_close( + check_ops.assert_near( array_ops.zeros([], dtype=self.dtype), math_ops.reduce_logsumexp(x, axis=[-1])), ], x) @kullback_leibler.RegisterKL(OneHotCategorical, OneHotCategorical) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_categorical_categorical(a, b, name=None): """Calculate the batched KL divergence KL(a || b) with a, b OneHotCategorical. diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index 04de8106ee0c06f4bc888964e053eb3123f3dab3..3d055085cc7386e57a71aa310458b7666bb9a396 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ "Poisson", @@ -65,6 +66,14 @@ class Poisson(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, rate=None, log_rate=None, diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 7b10ba998f0ceac37571524ce858bbd4c87455fe..7a7ad1be35b80ff0f000181ea0778ab282a8220f 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -33,6 +33,7 @@ from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.distributions import transformed_distribution as transformed_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -42,6 +43,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_lognormal_gauss_hermite( loc, scale, quadrature_size, validate_args=False, name=None): # pylint: disable=unused-argument @@ -85,6 +94,14 @@ def quadrature_scheme_lognormal_gauss_hermite( return grid, probs +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_lognormal_quantiles( loc, scale, quadrature_size, validate_args=False, name=None): @@ -214,6 +231,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): validate_args=True) """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -417,6 +442,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): axis=[-2, -1]) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [distribution_util.static_value(x) for x in args] diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index 5ac6c34b538016af376f53aa5a889e78c1f65f5f..ef3bdfa75fcaa8df17db1238ceadadf788601356 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -27,10 +27,19 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distributions from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = ["QuantizedDistribution"] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _logsum_expbig_minus_expsmall(big, small): """Stable evaluation of `Log[exp{big} - exp{small}]`. @@ -228,6 +237,14 @@ class QuantizedDistribution(distributions.Distribution): https://arxiv.org/abs/1711.10433 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution, low=None, diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index 4182ca2b56ea80dba71787b006a1652e0f979694..7e1f64dc425e6a576bfbe1bb456901fddfac26e1 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -19,15 +19,16 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import logistic +from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid # Bijectors must be directly imported because `remove_undocumented` prevents # individual file imports. -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class RelaxedBernoulli(transformed_distribution.TransformedDistribution): @@ -131,6 +132,14 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): Gumbel-Softmax. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, temperature, logits=None, diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 5414f347cd65e2d3327d1934cbc7a91e7f780fc5..25aaac379a7c54c832bdcf962e16f339522d61fc 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class ExpRelaxedOneHotCategorical(distribution.Distribution): @@ -125,6 +126,14 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): A Continuous Relaxation of Discrete Random Variables. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, temperature, @@ -290,7 +299,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): return x return control_flow_ops.with_dependencies([ check_ops.assert_non_positive(x), - distribution_util.assert_close( + check_ops.assert_near( array_ops.zeros([], dtype=self.dtype), math_ops.reduce_logsumexp(x, axis=[-1])), ], x) @@ -368,6 +377,14 @@ class RelaxedOneHotCategorical( A Continuous Relaxation of Discrete Random Variables. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, temperature, diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index 6a7f28713acefd2285b07a212e2e47a6db1ae5e1..4f348be2806aa3ade7c1ea2a7bc68ca26db6447f 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class _DistributionShape(object): @@ -166,6 +167,14 @@ class _DistributionShape(object): "free," i.e., during graph construction. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, batch_ndims=None, event_ndims=None, diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index a764544932cea8a624820153e383595fec9d7fc6..a9d0fb4ccfb1803873f7fe17089f3e7c7f10f4b7 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation __all__ = [ "SinhArcsinh", @@ -94,6 +95,14 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 8d4914e16cd3748e81e3d9b3be8b35f64a1c6f0d..ece03fe4aab3cc3046e0958d883ca9388517b94b 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -40,6 +40,7 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -49,6 +50,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_softmaxnormal_gauss_hermite( normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): @@ -111,6 +120,14 @@ def quadrature_scheme_softmaxnormal_gauss_hermite( return grid, probs +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_softmaxnormal_quantiles( normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): @@ -318,6 +335,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): https://arxiv.org/abs/1801.03080 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, mix_loc, temperature, @@ -779,6 +804,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): return array_ops.reshape(p, shape=expand_shape) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def maybe_check_quadrature_param(param, name, validate_args): """Helper which checks validity of `loc` and `scale` init args.""" with ops.name_scope(name="check_" + name, values=[param]): @@ -812,6 +845,14 @@ def maybe_check_quadrature_param(param, name, validate_args): return param +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def determine_batch_event_shapes(grid, endpoint_affine): """Helper to infer batch_shape and event_shape.""" with ops.name_scope(name="determine_batch_event_shapes"): @@ -850,6 +891,14 @@ def determine_batch_event_shapes(grid, endpoint_affine): return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def interpolate_loc(grid, loc): """Helper which interpolates between two locs.""" if len(loc) != 2: @@ -876,6 +925,14 @@ def interpolate_loc(grid, loc): return [x[..., k] for k in range(deg)] # list(shape:[B, e]) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def interpolate_scale(grid, scale): """Helper which interpolates between two scales.""" if len(scale) != 2: @@ -892,6 +949,14 @@ def interpolate_scale(grid, scale): ])[0] for q in range(deg)] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def linop_scale(w, op): # We assume w > 0. (This assumption only relates to the is_* attributes.) with ops.name_scope("linop_scale", values=[w]): @@ -927,6 +992,14 @@ def linop_scale(w, op): "Unsupported Linop type ({})".format(type(op).__name__)) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [distribution_util.static_value(x) for x in args] @@ -935,6 +1008,14 @@ def concat_vectors(*args): return [val for vec in args_ for val in vec] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def add(x, y): """Adds inputs; interprets `None` as zero.""" if x is None: @@ -944,11 +1025,27 @@ def add(x, y): return x + y +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def vec_osquare(x): """Computes the outer-product of a (batch of) vector, i.e., x.T x.""" return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def softmax(x, axis, name=None): """Equivalent to tf.nn.softmax but works around b/70297725.""" with ops.name_scope(name, "softmax", [x, axis]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index a75b3f3df1f2867f214f47051fa358b79a52a35e..73356a3625c9a1aa15af5b6c1cf2ccb0c514b39a 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import vector_exponential_linear_operator as vector_exponential_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -116,6 +117,14 @@ class VectorExponentialDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index a7d4c55be93f6190ae4d6976030190f27dcfe48f..9a47b4855763a25b484ad04a3415d191f19256f7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import exponential from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = ["VectorExponentialLinearOperator"] @@ -138,6 +139,14 @@ class VectorExponentialLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 4a53e7a621f27382d2995798f724392d34459670..e68ddc569c95ff63760b4b2f6d7a92f17240a558 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import vector_laplace_linear_operator as vector_laplace_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -151,6 +152,14 @@ class VectorLaplaceDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index 0566e04fece6f9ca0de6903ce5c424eccbc003cd..3923161a332a77e4eaab8d65d96fd8c278c872ec 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import laplace from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = [ @@ -154,6 +155,14 @@ class VectorLaplaceLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index bb33cd0762a368eb7e53f1623ede9231e80f0b14..49ffff24caec8d6c525f65f06796d10548d5ec40 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation __all__ = [ "VectorSinhArcsinhDiag", @@ -95,6 +96,14 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 21f84dcbdea8b422dd45fadeac1bb8b2804c551f..f289b39e51aff36780541a0545ed9e6cfe21dd4e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import student_t from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation class _VectorStudentT(transformed_distribution.TransformedDistribution): @@ -121,6 +122,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, loc=None, diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 88d4280759da7ca685056f4d41cf8dc51393c9f3..f1accaaa4c920344608015c792a2c3606de1337f 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "WishartCholesky", @@ -73,6 +74,14 @@ class _WishartLinearOperator(distribution.Distribution): this class. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale_operator, @@ -501,6 +510,14 @@ class WishartCholesky(_WishartLinearOperator): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale, @@ -617,6 +634,14 @@ class WishartFull(_WishartLinearOperator): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale, diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 1d9371c7ac405dbf0ec40210270b90f2cf9b9a25..12155a459c29c353c57679c407e7dda25047a35c 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -11,8 +11,12 @@ py_library( "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", "//tensorflow/contrib/eager/python/examples/linear_regression", "//tensorflow/contrib/eager/python/examples/resnet50", + "//tensorflow/contrib/eager/python/examples/revnet", + "//tensorflow/contrib/eager/python/examples/revnet:config", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/sagan", + "//tensorflow/contrib/eager/python/examples/sagan:config", "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py index 98b4ce1b26acf2d934ed7abf6452d200cc9e7e80..729d8525fab31ee214178ca1bcb18dbd069f767a 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py @@ -57,11 +57,6 @@ class Dynamics(tf.keras.Model): self.eps = tfe.Variable( initial_value=eps, name="eps", dtype=tf.float32, trainable=True) - # TODO(lxuechen): Remove this after model.add_weight is in place - self.vars_not_in_layers = [self.eps] - self.vars_not_in_layers += self.position_fn.vars_not_in_layers - self.vars_not_in_layers += self.momentum_fn.vars_not_in_layers - def apply_transition(self, position): """Propose a new state and perform the accept or reject step.""" @@ -290,86 +285,35 @@ class Dynamics(tf.keras.Model): return grad -# Defining loss and grads for training -def compute_loss(x, dynamics, scale=.1, eps=1e-4): - """Compute loss defined in equation (8).""" - - z = tf.random_normal(tf.shape(x)) - x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) - z_, _, z_accept_prob, _ = dynamics.apply_transition(z) - - # Add eps for numerical stability; following released impl - x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps - z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps - - loss = tf.reduce_mean( - (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0) - - return loss, x_out - - -def loss_and_grads(x, dynamics): - """Obtain loss value and gradients.""" - - with tf.GradientTape() as tape: - loss_val, x_out = compute_loss(x, dynamics) - - vars_ = dynamics.variables + dynamics.vars_not_in_layers - grads = tape.gradient(loss_val, vars_) - - return loss_val, grads, x_out - - -def warmup(dynamics, optimizer, n_iters=1, n_samples=200): - """Warmup optimization to reduce overhead.""" - - samples = tf.random_normal( - shape=[n_samples, dynamics.x_dim], dtype=tf.float32) - - for _ in range(n_iters): - _, grads, samples = loss_and_grads(samples, dynamics) - vars_ = dynamics.variables + dynamics.vars_not_in_layers - optimizer.apply_gradients(zip(grads, vars_)) - - -def fit(dynamics, - optimizer, - n_samples=200, - n_iters=5000, - verbose=True, - logdir=None): - """Fit L2HMC sampler with given log-likelihood function.""" - - if logdir: - summary_writer = tf.contrib.summary.create_file_writer(logdir) +# Examples of unnormalized log density/probabilities +def get_scg_energy_fn(): + """Get energy function for 2d strongly correlated Gaussian.""" - samples = tf.random_normal( - shape=[n_samples, dynamics.x_dim], dtype=tf.float32) + # Avoid recreating tf constants on each invocation of gradients + mu = tf.constant([0., 0.]) + sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]]) + sigma_inv = tf.matrix_inverse(sigma) - tf.train.get_or_create_global_step() - for i in range(n_iters): - loss, grads, samples = loss_and_grads(samples, dynamics) - # TODO(lxuechen): Proper learning rate decay - grads_ = [grad * .96**(i // 1000) for grad in grads] - vars_ = dynamics.variables + dynamics.vars_not_in_layers - optimizer.apply_gradients( - zip(grads_, vars_), global_step=tf.train.get_global_step()) + def energy(x): + """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" - if verbose: - print("Iteration %d: loss %.4f" % (i, loss)) + xmmu = x - mu + return .5 * tf.diag_part( + tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) - if logdir: - with summary_writer.as_default(): - with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar("loss", loss) + return energy -def get_scg_energy_fn(): +def get_multivariate_gaussian_energy_fn(x_dim=2): """Get energy function for 2d strongly correlated Gaussian.""" - # Avoid recreating tf constants on each invocation of gradients - mu = tf.constant([0., 0.]) - sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]]) + mu = tf.random_normal(shape=[x_dim]) + # Lower triangularize and positive diagonal + l = tf.sigmoid( + tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0)) + # Exploit Cholesky decomposition + sigma = tf.matmul(l, tf.transpose(l)) + sigma *= 100. # Small covariance causes extreme numerical instability sigma_inv = tf.matrix_inverse(sigma) def energy(x): diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py index 522a7c9380131b6eddd241e2450bae248ad15ccf..e33b4cae4c73388dfd78542c9907953f137ad710 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py @@ -32,16 +32,83 @@ def get_default_hparams(): n_samples=200, n_steps=10, eps=.1, - n_iters=5, - learning_rate=.001, - n_warmup_iters=1) + n_iters=10, + learning_rate=.0003, + n_warmup_iters=3) + + +# Relevant functions for benchmarking +def compute_loss(dynamics, x, scale=.1, eps=1e-4): + """Compute loss defined in equation (8).""" + + z = tf.random_normal(tf.shape(x)) + x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) + z_, _, z_accept_prob, _ = dynamics.apply_transition(z) + + # Add eps for numerical stability; following released impl + x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps + z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps + + loss = tf.reduce_mean( + (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0) + + return loss, x_out + + +def loss_and_grads(dynamics, x, loss_fn=compute_loss): + """Obtain loss value and gradients.""" + + with tf.GradientTape() as tape: + loss_val, x_out = loss_fn(dynamics, x) + grads = tape.gradient(loss_val, dynamics.variables) + + return loss_val, grads, x_out + + +def warmup(dynamics, optimizer, n_iters=1, n_samples=200, loss_fn=compute_loss): + """Warmup optimization to reduce overhead.""" + + samples = tf.random_normal( + shape=[n_samples, dynamics.x_dim], dtype=tf.float32) + + for _ in range(n_iters): + _, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + optimizer.apply_gradients(zip(grads, dynamics.variables)) + + +def fit(dynamics, + samples, + optimizer, + loss_fn=compute_loss, + n_iters=5000, + verbose=True, + logdir=None, + decay_lr=True): + """Fit L2HMC sampler with given log-likelihood function.""" + + if logdir: + summary_writer = tf.contrib.summary.create_file_writer(logdir) + + for i in range(n_iters): + loss, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + # TODO(lxuechen): Proper learning rate decay + if decay_lr: + grads = [grad * .96**(i // 1000) for grad in grads] + optimizer.apply_gradients(zip(grads, dynamics.variables)) + if verbose: + print("Iteration %d: loss %.4f" % (i, loss)) + + if logdir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("loss", loss) class L2hmcTest(tf.test.TestCase): """Unit tests for l2hmc in both eager and graph mode.""" - def testComputeLoss(self): - """Testing function l2hmc.compute_loss in both graph and eager mode.""" + def test_apply_transition(self): + """Testing function `Dynamics.apply_transition` in graph and eager mode.""" # Eager mode testing hparams = get_default_hparams() @@ -51,12 +118,12 @@ class L2hmcTest(tf.test.TestCase): n_steps=hparams.n_steps, eps=hparams.eps) samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim]) - loss, x_out = l2hmc.compute_loss(samples, dynamics) + x_, v_, x_accept_prob, x_out = dynamics.apply_transition(samples) - # Check shape and numerical stability + self.assertEqual(x_.shape, v_.shape) self.assertEqual(x_out.shape, samples.shape) - self.assertEqual(loss.shape, []) - self.assertAllClose(loss.numpy(), loss.numpy(), rtol=1e-5) + self.assertEqual(x_.shape, x_out.shape) + self.assertEqual(x_accept_prob.shape, (hparams.n_samples,)) # Graph mode testing with tf.Graph().as_default(): @@ -66,65 +133,49 @@ class L2hmcTest(tf.test.TestCase): n_steps=hparams.n_steps, eps=hparams.eps) x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) - loss, x_out = l2hmc.compute_loss(x, dynamics) + x_, v_, x_accept_prob, x_out = dynamics.apply_transition(x) samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) - loss_np, x_out_np = sess.run([loss, x_out], feed_dict={x: samples}) + np_x_, np_v_, np_x_accept_prob, np_x_out = sess.run( + [x_, v_, x_accept_prob, x_out], feed_dict={x: samples}) - # Check shape and numerical stability - self.assertEqual(x_out_np.shape, samples.shape) - self.assertEqual(loss_np.shape, ()) - self.assertAllClose(loss_np, loss_np, rtol=1e-5) + self.assertEqual(np_x_.shape, np_v_.shape) + self.assertEqual(samples.shape, np_x_out.shape) + self.assertEqual(np_x_.shape, np_x_out.shape) + self.assertEqual(np_x_accept_prob.shape, (hparams.n_samples,)) class L2hmcBenchmark(tf.test.Benchmark): """Eager and graph benchmarks for l2hmc.""" - def benchmarkEagerL2hmc(self): - """Benchmark Eager performance.""" - - hparams = get_default_hparams() - dynamics = l2hmc.Dynamics( - x_dim=hparams.x_dim, - loglikelihood_fn=l2hmc.get_scg_energy_fn(), - n_steps=hparams.n_steps, - eps=hparams.eps) - # TODO(lxuechen): Add learning rate decay - optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) - - # Warmup to reduce initialization effect when timing - l2hmc.warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters) + def _get_energy_fn(self): + """Get specific energy function according to FLAGS.""" - # Time - start_time = time.time() - l2hmc.fit( - dynamics, - optimizer, - n_samples=hparams.n_samples, - n_iters=hparams.n_iters) - wall_time = time.time() - start_time - examples_per_sec = hparams.n_samples / wall_time + if FLAGS.energy_fn == "scg": + energy_fn = l2hmc.get_scg_energy_fn() + elif FLAGS.energy_fn == "multivariate_gaussian": + energy_fn = l2hmc.get_multivariate_gaussian_energy_fn(x_dim=FLAGS.x_dim) + else: + raise ValueError("No such energy function %s" % FLAGS.energy_fn) - self.report_benchmark( - name="eager_train_%s" % ("gpu" if tfe.num_gpus() > 0 else "cpu"), - iters=hparams.n_iters, - extras={"examples_per_sec": examples_per_sec}, - wall_time=wall_time) + return energy_fn - def benchmarkGraphL2hmc(self): + def benchmark_graph(self): """Benchmark Graph performance.""" hparams = get_default_hparams() + tf.reset_default_graph() with tf.Graph().as_default(): + energy_fn = self._get_energy_fn() dynamics = l2hmc.Dynamics( x_dim=hparams.x_dim, - loglikelihood_fn=l2hmc.get_scg_energy_fn(), + loglikelihood_fn=energy_fn, n_steps=hparams.n_steps, eps=hparams.eps) x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) - loss, x_out = l2hmc.compute_loss(x, dynamics) + loss, x_out = compute_loss(dynamics, x) global_step = tf.Variable(0., name="global_step", trainable=False) learning_rate = tf.train.exponential_decay( @@ -138,14 +189,15 @@ class L2hmcBenchmark(tf.test.Benchmark): # Warmup to reduce initialization effect when timing samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) for _ in range(hparams.n_warmup_iters): - samples, _, _, _ = sess.run( + _, _, _, _ = sess.run( [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) - # Time + # Training start_time = time.time() - for _ in range(hparams.n_iters): - samples, _, _, _ = sess.run( + for i in range(hparams.n_iters): + samples, loss_np, _, _ = sess.run( [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + print("Iteration %d: loss %.4f" % (i, loss_np)) wall_time = time.time() - start_time examples_per_sec = hparams.n_samples / wall_time @@ -156,7 +208,57 @@ class L2hmcBenchmark(tf.test.Benchmark): extras={"examples_per_sec": examples_per_sec}, wall_time=wall_time) + def benchmark_eager(self): + self._benchmark_eager() + + def benchmark_eager_defun(self): + self._benchmark_eager(defun=True) + + def _benchmark_eager(self, defun=False): + """Benchmark Eager performance.""" + + hparams = get_default_hparams() + energy_fn = self._get_energy_fn() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) + loss_fn = tfe.defun(compute_loss) if defun else compute_loss + + # Warmup to reduce initialization effect when timing + warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn) + + # Training + samples = tf.random_normal( + shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32) + start_time = time.time() + fit(dynamics, + samples, + optimizer, + loss_fn=loss_fn, + n_iters=hparams.n_iters, + decay_lr=True) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="eager_train_%s%s" % ("gpu" if tf.test.is_gpu_available() else + "cpu", "_defun" if defun else ""), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + del dynamics + del loss_fn + if __name__ == "__main__": + tf.flags.DEFINE_string("energy_fn", "scg", + ("The energy function/unnormalized log-probability. " + "Either be `scg` or `multivariate_gaussian`")) + tf.flags.DEFINE_integer("x_dim", 2, "Dimensionality of observation space.") + FLAGS = tf.flags.FLAGS tf.enable_eager_execution() tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py index c902e1f1f4862d704149fd4794f2a65ab8709640..e230ad5e259df5b450897bd815e901e3934cd293 100644 --- a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py +++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py @@ -57,8 +57,6 @@ class GenericNet(tf.keras.Model): initial_value=tf.zeros([1, x_dim]), name='coeff_transformation', trainable=True) - # TODO(lxuechen): Remove this after model.add_weight is in place - self.vars_not_in_layers = [self.coeff_scale, self.coeff_transformation] def call(self, inputs): v, x, t = inputs diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..54ebcad8e929c3195099121a290dd7c0651e5c9f --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -0,0 +1,909 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "nmt_with_attention.ipynb", + "version": "0.3.2", + "views": {}, + "default_view": {}, + "provenance": [ + { + "file_id": "1C4fpM7_7IL8ZzF7Gc5abywqQjeQNS2-U", + "timestamp": 1527858391290 + }, + { + "file_id": "1pExo6aUuw0S6MISFWoinfJv0Ftm9V4qv", + "timestamp": 1527776041613 + } + ], + "private_outputs": true, + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "metadata": { + "id": "AOpGoE2T-YXS", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\").\n", + "\n", + "# Neural Machine Translation with Attention\n", + "\n", + "
\n", + "\n", + " Run in Google Colab \n", + "\n", + "View source on Github
" + ] + }, + { + "metadata": { + "id": "CiwtNgENbx2g", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "This notebook trains a sequence to sequence (seq2seq) model for Spanish to English translation using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). This is an advanced example that assumes some knowledge of sequence to sequence models.\n", + "\n", + "After training the model in this notebook, you will be able to input a Spanish sentence, such as *\"¿todavia estan en casa?\"*, and return the English translation: *\"are you still at home?\"*\n", + "\n", + "The translation quality is reasonable for a toy example, but the generated attention plot is perhaps more interesting. This shows which parts of the input sentence has the model's attention while translating:\n", + "\n", + "\"spanish-english\n", + "\n", + "Note: This example takes approximately 10 mintues to run on a single P100 GPU." + ] + }, + { + "metadata": { + "id": "tnxXKDjq3jEL", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "from __future__ import absolute_import, division, print_function\n", + "\n", + "# Import TensorFlow >= 1.9 and enable eager execution\n", + "import tensorflow as tf\n", + "\n", + "tf.enable_eager_execution()\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "import unicodedata\n", + "import re\n", + "import numpy as np\n", + "import os\n", + "import time\n", + "\n", + "print(tf.__version__)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "wfodePkj3jEa", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Download and prepare the dataset\n", + "\n", + "We'll use a language dataset provided by http://www.manythings.org/anki/. This dataset contains language translation pairs in the format:\n", + "\n", + "```\n", + "May I borrow this book?\t¿Puedo tomar prestado este libro?\n", + "```\n", + "\n", + "There are a variety of languages available, but we'll use the English-Spanish dataset. For convenience, we've hosted a copy of this dataset on Google Cloud, but you can also download your own copy. After downloading the dataset, here are the steps we'll take to prepare the data:\n", + "\n", + "1. Add a *start* and *end* token to each sentence.\n", + "2. Clean the sentences by removing special characters.\n", + "3. Create a word index and reverse word index (dictionaries mapping from word → id and id → word).\n", + "4. Pad each sentence to a maximum length." + ] + }, + { + "metadata": { + "id": "kRVATYOgJs1b", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Download the file\n", + "path_to_zip = tf.keras.utils.get_file(\n", + " 'spa-eng.zip', origin='http://download.tensorflow.org/data/spa-eng.zip', \n", + " extract=True)\n", + "\n", + "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "rd0jw-eC3jEh", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Converts the unicode file to ascii\n", + "def unicode_to_ascii(s):\n", + " return ''.join(c for c in unicodedata.normalize('NFD', s)\n", + " if unicodedata.category(c) != 'Mn')\n", + "\n", + "\n", + "def preprocess_sentence(w):\n", + " w = unicode_to_ascii(w.lower().strip())\n", + " \n", + " # creating a space between a word and the punctuation following it\n", + " # eg: \"he is a boy.\" => \"he is a boy .\" \n", + " # Reference:- https://stackoverflow.com/questions/3645931/python-padding-punctuation-with-white-spaces-keeping-punctuation\n", + " w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n", + " w = re.sub(r'[\" \"]+', \" \", w)\n", + " \n", + " # replacing everything with space except (a-z, A-Z, \".\", \"?\", \"!\", \",\")\n", + " w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n", + " \n", + " w = w.rstrip().strip()\n", + " \n", + " # adding a start and an end token to the sentence\n", + " # so that the model know when to start and stop predicting.\n", + " w = ' ' + w + ' '\n", + " return w" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "OHn4Dct23jEm", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# 1. Remove the accents\n", + "# 2. Clean the sentences\n", + "# 3. Return word pairs in the format: [ENGLISH, SPANISH]\n", + "def create_dataset(path, num_examples):\n", + " lines = open(path, encoding='UTF-8').read().strip().split('\\n')\n", + " \n", + " word_pairs = [[preprocess_sentence(w) for w in l.split('\\t')] for l in lines[:num_examples]]\n", + " \n", + " return word_pairs" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "9xbqO7Iie9bb", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# This class creates a word -> index mapping (e.g,. \"dad\" -> 5) and vice-versa \n", + "# (e.g., 5 -> \"dad\") for each language,\n", + "class LanguageIndex():\n", + " def __init__(self, lang):\n", + " self.lang = lang\n", + " self.word2idx = {}\n", + " self.idx2word = {}\n", + " self.vocab = set()\n", + " \n", + " self.create_index()\n", + " \n", + " def create_index(self):\n", + " for phrase in self.lang:\n", + " self.vocab.update(phrase.split(' '))\n", + " \n", + " self.vocab = sorted(self.vocab)\n", + " \n", + " self.word2idx[''] = 0\n", + " for index, word in enumerate(self.vocab):\n", + " self.word2idx[word] = index + 1\n", + " \n", + " for word, index in self.word2idx.items():\n", + " self.idx2word[index] = word" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "eAY9k49G3jE_", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def max_length(tensor):\n", + " return max(len(t) for t in tensor)\n", + "\n", + "\n", + "def load_dataset(path, num_examples):\n", + " # creating cleaned input, output pairs\n", + " pairs = create_dataset(path, num_examples)\n", + "\n", + " # index language using the class defined above \n", + " inp_lang = LanguageIndex(sp for en, sp in pairs)\n", + " targ_lang = LanguageIndex(en for en, sp in pairs)\n", + " \n", + " # Vectorize the input and target languages\n", + " \n", + " # Spanish sentences\n", + " input_tensor = [[inp_lang.word2idx[s] for s in sp.split(' ')] for en, sp in pairs]\n", + " \n", + " # English sentences\n", + " target_tensor = [[targ_lang.word2idx[s] for s in en.split(' ')] for en, sp in pairs]\n", + " \n", + " # Calculate max_length of input and output tensor\n", + " # Here, we'll set those to the longest sentence in the dataset\n", + " max_length_inp, max_length_tar = max_length(input_tensor), max_length(target_tensor)\n", + " \n", + " # Padding the input and output tensor to the maximum length\n", + " input_tensor = tf.keras.preprocessing.sequence.pad_sequences(input_tensor, \n", + " maxlen=max_length_inp,\n", + " padding='post')\n", + " \n", + " target_tensor = tf.keras.preprocessing.sequence.pad_sequences(target_tensor, \n", + " maxlen=max_length_tar, \n", + " padding='post')\n", + " \n", + " return input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_tar" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "GOi42V79Ydlr", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Limit the size of the dataset to experiment faster (optional)\n", + "\n", + "Training on the complete dataset of >100,000 sentences will take a long time. To train faster, we can limit the size of the dataset to 30,000 sentences (of course, translation quality degrades with less data):" + ] + }, + { + "metadata": { + "id": "cnxC7q-j3jFD", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Try experimenting with the size of that dataset\n", + "num_examples = 30000\n", + "input_tensor, target_tensor, inp_lang, targ_lang, max_length_inp, max_length_targ = load_dataset(path_to_file, num_examples)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "4QILQkOs3jFG", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# Creating training and validation sets using an 80-20 split\n", + "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.2)\n", + "\n", + "# Show length\n", + "len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "rgCLkfv5uO3d", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "### Create a tf.data dataset" + ] + }, + { + "metadata": { + "id": "TqHsArVZ3jFS", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "BUFFER_SIZE = len(input_tensor_train)\n", + "BATCH_SIZE = 64\n", + "embedding_dim = 256\n", + "units = 1024\n", + "vocab_inp_size = len(inp_lang.word2idx)\n", + "vocab_tar_size = len(targ_lang.word2idx)\n", + "\n", + "dataset = tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train)).shuffle(BUFFER_SIZE)\n", + "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "TNfHIF71ulLu", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Write the encoder and decoder model\n", + "\n", + "Here, we'll implement an encoder-decoder model with attention which you can read about in the TensorFlow [Neural Machine Translation (seq2seq) tutorial](https://www.tensorflow.org/tutorials/seq2seq). This example uses a more recent set of APIs. This notebook implements the [attention equations](https://www.tensorflow.org/tutorials/seq2seq#background_on_the_attention_mechanism) from the seq2seq tutorial. The following diagram shows that each input words is assigned a weight by the attention mechanism which is then used by the decoder to predict the next word in the sentence.\n", + "\n", + "\"attention\n", + "\n", + "The input is put through an encoder model which gives us the encoder output of shape *(batch_size, max_length, hidden_size)* and the encoder hidden state of shape *(batch_size, hidden_size)*. \n", + "\n", + "Here are the equations that are implemented:\n", + "\n", + "\"attention\n", + "\"attention\n", + "\n", + "We're using *Bahdanau attention*. Lets decide on notation before writing the simplified form:\n", + "\n", + "* FC = Fully connected (dense) layer\n", + "* EO = Encoder output\n", + "* H = hidden state\n", + "* X = input to the decoder\n", + "\n", + "And the pseudo-code:\n", + "\n", + "* `score = FC(tanh(FC(EO) + FC(H)))`\n", + "* `attention weights = softmax(score, axis = 1)`. Softmax by default is applied on the last axis but here we want to apply it on the *1st axis*, since the shape of score is *(batch_size, max_length, hidden_size)*. `Max_length` is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.\n", + "* `context vector = sum(attention weights * EO, axis = 1)`. Same reason as above for choosing axis as 1.\n", + "* `embedding output` = The input to the decoder X is passed through an embedding layer.\n", + "* `merged vector = concat(embedding output, context vector)`\n", + "* This merged vector is then given to the GRU\n", + " \n", + "The shapes of all the vectors at each step have been specified in the comments in the code:" + ] + }, + { + "metadata": { + "id": "avyJ_4VIUoHb", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def gru(units):\n", + " # If you have a GPU, we recommend using CuDNNGRU(provides a 3x speedup than GRU)\n", + " # the code automatically does that.\n", + " if tf.test.is_gpu_available():\n", + " return tf.keras.layers.CuDNNGRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_initializer='glorot_uniform')\n", + " else:\n", + " return tf.keras.layers.GRU(units, \n", + " return_sequences=True, \n", + " return_state=True, \n", + " recurrent_activation='sigmoid', \n", + " recurrent_initializer='glorot_uniform')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "nZ2rI24i3jFg", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class Encoder(tf.keras.Model):\n", + " def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):\n", + " super(Encoder, self).__init__()\n", + " self.batch_sz = batch_sz\n", + " self.enc_units = enc_units\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.enc_units)\n", + " \n", + " def call(self, x, hidden):\n", + " x = self.embedding(x)\n", + " output, state = self.gru(x, initial_state = hidden) \n", + " return output, state\n", + " \n", + " def initialize_hidden_state(self):\n", + " return tf.zeros((self.batch_sz, self.enc_units))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "yJ_B3mhW3jFk", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "class Decoder(tf.keras.Model):\n", + " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", + " super(Decoder, self).__init__()\n", + " self.batch_sz = batch_sz\n", + " self.dec_units = dec_units\n", + " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", + " self.gru = gru(self.dec_units)\n", + " self.fc = tf.keras.layers.Dense(vocab_size)\n", + " \n", + " # used for attention\n", + " self.W1 = tf.keras.layers.Dense(self.dec_units)\n", + " self.W2 = tf.keras.layers.Dense(self.dec_units)\n", + " self.V = tf.keras.layers.Dense(1)\n", + " \n", + " def call(self, x, hidden, enc_output):\n", + " # enc_output shape == (batch_size, max_length, hidden_size)\n", + " \n", + " # hidden shape == (batch_size, hidden size)\n", + " # hidden_with_time_axis shape == (batch_size, 1, hidden size)\n", + " # we are doing this to perform addition to calculate the score\n", + " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n", + " \n", + " # score shape == (batch_size, max_length, hidden_size)\n", + " score = tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis))\n", + " \n", + " # attention_weights shape == (batch_size, max_length, 1)\n", + " # we get 1 at the last axis because we are applying score to self.V\n", + " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n", + " \n", + " # context_vector shape after sum == (batch_size, hidden_size)\n", + " context_vector = attention_weights * enc_output\n", + " context_vector = tf.reduce_sum(context_vector, axis=1)\n", + " \n", + " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n", + " x = self.embedding(x)\n", + " \n", + " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n", + " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", + " \n", + " # passing the concatenated vector to the GRU\n", + " output, state = self.gru(x)\n", + " \n", + " # output shape == (batch_size * max_length, hidden_size)\n", + " output = tf.reshape(output, (-1, output.shape[2]))\n", + " \n", + " # output shape == (batch_size * max_length, vocab)\n", + " x = self.fc(output)\n", + " \n", + " return x, state, attention_weights\n", + " \n", + " def initialize_hidden_state(self):\n", + " return tf.zeros((self.batch_sz, self.dec_units))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "P5UY8wko3jFp", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", + "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "_ch_71VbIRfK", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Define the optimizer and the loss function" + ] + }, + { + "metadata": { + "id": "WmTHr5iV3jFr", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "optimizer = tf.train.AdamOptimizer()\n", + "\n", + "\n", + "def loss_function(real, pred):\n", + " mask = 1 - np.equal(real, 0)\n", + " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n", + " return tf.reduce_mean(loss_)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "hpObfY22IddU", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Training\n", + "\n", + "1. Pass the *input* through the *encoder* which return *encoder output* and the *encoder hidden state*.\n", + "2. The encoder output, encoder hidden state and the decoder input (which is the *start token*) is passed to the decoder.\n", + "3. The decoder returns the *predictions* and the *decoder hidden state*.\n", + "4. The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n", + "5. Use *teacher forcing* to decide the next input to the decoder.\n", + "6. *Teacher forcing* is the technique where the *target word* is passed as the *next input* to the decoder.\n", + "7. The final step is to calculate the gradients and apply it to the optimizer and backpropagate." + ] + }, + { + "metadata": { + "id": "ddefjBMa3jF0", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "EPOCHS = 10\n", + "\n", + "for epoch in range(EPOCHS):\n", + " start = time.time()\n", + " \n", + " hidden = encoder.initialize_hidden_state()\n", + " total_loss = 0\n", + " \n", + " for (batch, (inp, targ)) in enumerate(dataset):\n", + " loss = 0\n", + " \n", + " with tf.GradientTape() as tape:\n", + " enc_output, enc_hidden = encoder(inp, hidden)\n", + " \n", + " dec_hidden = enc_hidden\n", + " \n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']] * BATCH_SIZE, 1) \n", + " \n", + " # Teacher forcing - feeding the target as the next input\n", + " for t in range(1, targ.shape[1]):\n", + " # passing enc_output to the decoder\n", + " predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)\n", + " \n", + " loss += loss_function(targ[:, t], predictions)\n", + " \n", + " # using teacher forcing\n", + " dec_input = tf.expand_dims(targ[:, t], 1)\n", + " \n", + " total_loss += (loss / int(targ.shape[1]))\n", + " \n", + " variables = encoder.variables + decoder.variables\n", + " \n", + " gradients = tape.gradient(loss, variables)\n", + " \n", + " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n", + "\n", + " if batch % 100 == 0:\n", + " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", + " batch,\n", + " loss.numpy() / int(targ.shape[1])))\n", + " \n", + " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", + " total_loss/len(input_tensor)))\n", + " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "mU3Ce8M6I3rz", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Translate\n", + "\n", + "* The evaluate function is similar to the training loop, except we don't use *teacher forcing* here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n", + "* Stop predicting when the model predicts the *end token*.\n", + "* And store the *attention weights for every time step*.\n", + "\n", + "Note: The encoder output is calculated only once for one input." + ] + }, + { + "metadata": { + "id": "EbQpyYs13jF_", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", + " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", + " \n", + " sentence = preprocess_sentence(sentence)\n", + "\n", + " inputs = [inp_lang.word2idx[i] for i in sentence.split(' ')]\n", + " inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], maxlen=max_length_inp, padding='post')\n", + " inputs = tf.convert_to_tensor(inputs)\n", + " \n", + " result = ''\n", + "\n", + " hidden = [tf.zeros((1, units))]\n", + " enc_out, enc_hidden = encoder(inputs, hidden)\n", + "\n", + " dec_hidden = enc_hidden\n", + " dec_input = tf.expand_dims([targ_lang.word2idx['']], 0)\n", + "\n", + " for t in range(max_length_targ):\n", + " predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)\n", + " \n", + " # storing the attention weigths to plot later on\n", + " attention_weights = tf.reshape(attention_weights, (-1, ))\n", + " attention_plot[t] = attention_weights.numpy()\n", + "\n", + " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n", + "\n", + " result += targ_lang.idx2word[predicted_id] + ' '\n", + "\n", + " if targ_lang.idx2word[predicted_id] == '':\n", + " return result, sentence, attention_plot\n", + " \n", + " # the predicted ID is fed back into the model\n", + " dec_input = tf.expand_dims([predicted_id], 0)\n", + "\n", + " return result, sentence, attention_plot" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "s5hQWlbN3jGF", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# function for plotting the attention weights\n", + "def plot_attention(attention, sentence, predicted_sentence):\n", + " fig = plt.figure(figsize=(10,10))\n", + " ax = fig.add_subplot(1, 1, 1)\n", + " ax.matshow(attention, cmap='viridis')\n", + " \n", + " fontdict = {'fontsize': 14}\n", + " \n", + " ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)\n", + " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)\n", + "\n", + " plt.show()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "sl9zUHzg3jGI", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):\n", + " result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)\n", + " \n", + " print('Input: {}'.format(sentence))\n", + " print('Predicted translation: {}'.format(result))\n", + " \n", + " attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]\n", + " plot_attention(attention_plot, sentence.split(' '), result.split(' '))" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "WrAM0FDomq3E", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "translate('hace mucho frio aqui.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "zSx2iM36EZQZ", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "translate('esta es mi vida.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "A3LLCx3ZE0Ls", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "translate('¿todavia estan en casa?', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "DUQVLVqUE1YW", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "cell_type": "code", + "source": [ + "# wrong translation\n", + "translate('trata de averiguarlo.', encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "id": "RTe5P5ioMJwN", + "colab_type": "text" + }, + "cell_type": "markdown", + "source": [ + "## Next steps\n", + "\n", + "* [Download a different dataset](http://www.manythings.org/anki/) to experiment with translations, for example, English to German, or English to French.\n", + "* Experiment with training on a larger dataset, or using more epochs\n" + ] + } + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb index 4fe3a0e3f3d431684973a9251aa3d92bf2010444..5749f22ac58e0a012ed7e3fec4dfe2913d3f8273 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb @@ -68,7 +68,7 @@ "# simply construct the object. Most layers take as a first argument the number\n", "# of output dimensions / channels.\n", "layer = tf.keras.layers.Dense(100)\n", - "# The number of input dimensionss is often unnecessary, as it can be inferred\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", "# the first time the layer is used, but it can be provided if you want to \n", "# specify it manually, which is useful in some complex models.\n", "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" @@ -267,7 +267,7 @@ " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", " * `call`, where you do the forward computation\n", "\n", - "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes requires to create the variables will need to be explicitly specified." + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." ] }, { diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index 0c0e28dd95c68dc300384a128eb5aa2208f63a0d..68a84d5fbb4f13e4ebe0d71e3f5caebe97e2101c 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -51,5 +51,6 @@ cuda_py_test( "noasan", "nomsan", "notsan", + "optonly", ], ) diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..432bb546f83932d0e0a465d7af7c641b60d2e564 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD @@ -0,0 +1,114 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# Model +py_library( + name = "ops", + srcs = ["ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "config", + srcs = ["config.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "blocks", + srcs = ["blocks.py"], + srcs_version = "PY2AND3", + deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "revnet", + srcs = ["revnet.py"], + srcs_version = "PY2AND3", + deps = [ + ":blocks", + "//tensorflow:tensorflow_py", + ], +) + +# Tests +cuda_py_test( + name = "ops_test", + size = "large", + srcs = ["ops_test.py"], + additional_deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "blocks_test", + size = "large", + srcs = ["blocks_test.py"], + additional_deps = [ + ":blocks", + "//tensorflow:tensorflow_py", + ], + tags = [ + "optonly", + ], +) + +cuda_py_test( + name = "revnet_test", + size = "large", + srcs = ["revnet_test.py"], + additional_deps = [ + ":config", + ":revnet", + "//tensorflow:tensorflow_py", + ], + tags = [ + "optonly", + ], +) + +# Training +py_library( + name = "cifar_input", + srcs = ["cifar_input.py"], + srcs_version = "PY2AND3", + deps = [ + ":revnet", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "cifar_tfrecords", + srcs = ["cifar_tfrecords.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "main", + srcs = ["main.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar_input", + ":config", + ":revnet", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..af41f6428660dd6b80e1a28f7e70021fe260a9b5 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -0,0 +1,335 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Building blocks with manual backward gradient computation. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import ops + + +class RevBlock(tf.keras.Model): + """Single reversible block containing several `_Residual` blocks. + + Each `_Residual` block in turn contains two _ResidualInner blocks, + corresponding to the `F`/`G` functions in the paper. + """ + + def __init__(self, + n_res, + filters, + strides, + input_shape, + batch_norm_first=False, + data_format="channels_first", + bottleneck=False, + fused=True): + """Initialize RevBlock. + + Args: + n_res: number of residual blocks + filters: list/tuple of integers for output filter sizes of each residual + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + bottleneck: use bottleneck residual if True + fused: use fused batch normalization if True + """ + super(RevBlock, self).__init__() + self.blocks = tf.contrib.checkpoint.List() + for i in range(n_res): + curr_batch_norm_first = batch_norm_first and i == 0 + curr_strides = strides if i == 0 else (1, 1) + block = _Residual( + filters, + curr_strides, + input_shape, + batch_norm_first=curr_batch_norm_first, + data_format=data_format, + bottleneck=bottleneck, + fused=fused) + self.blocks.append(block) + + if data_format == "channels_first": + input_shape = (filters, input_shape[1] // curr_strides[0], + input_shape[2] // curr_strides[1]) + else: + input_shape = (input_shape[0] // curr_strides[0], + input_shape[1] // curr_strides[1], filters) + + def call(self, h, training=True): + """Apply reversible block to inputs.""" + + for block in self.blocks: + h = block(h, training=training) + return h + + def backward_grads_and_vars(self, x, y, dy, training=True): + """Apply reversible block backward to outputs.""" + + grads_all = [] + vars_all = [] + + for i in reversed(range(len(self.blocks))): + block = self.blocks[i] + y_inv = x if i == 0 else block.backward(y, training=training) + dy, grads, vars_ = block.backward_grads_and_vars( + y_inv, dy, training=training) + grads_all += grads + vars_all += vars_ + + return dy, grads_all, vars_all + + +class _Residual(tf.keras.Model): + """Single residual block contained in a _RevBlock. Each `_Residual` object has + two _ResidualInner objects, corresponding to the `F` and `G` functions in the + paper. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC", + bottleneck: use bottleneck residual if True + fused: use fused batch normalization if True + """ + + def __init__(self, + filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + bottleneck=False, + fused=True): + super(_Residual, self).__init__() + + self.filters = filters + self.strides = strides + self.axis = 1 if data_format == "channels_first" else 3 + if data_format == "channels_first": + f_input_shape = (input_shape[0] // 2,) + input_shape[1:] + g_input_shape = (filters // 2, input_shape[1] // strides[0], + input_shape[2] // strides[1]) + else: + f_input_shape = input_shape[:2] + (input_shape[2] // 2,) + g_input_shape = (input_shape[0] // strides[0], + input_shape[1] // strides[1], filters // 2) + + factory = _BottleneckResidualInner if bottleneck else _ResidualInner + self.f = factory( + filters=filters // 2, + strides=strides, + input_shape=f_input_shape, + batch_norm_first=batch_norm_first, + data_format=data_format, + fused=fused) + self.g = factory( + filters=filters // 2, + strides=(1, 1), + input_shape=g_input_shape, + batch_norm_first=batch_norm_first, + data_format=data_format, + fused=fused) + + def call(self, x, training=True, concat=True): + """Apply residual block to inputs.""" + + x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) + f_x2 = self.f.call(x2, training=training) + # TODO(lxuechen): Replace with simpler downsampling + x1_down = ops.downsample( + x1, self.filters // 2, self.strides, axis=self.axis) + x2_down = ops.downsample( + x2, self.filters // 2, self.strides, axis=self.axis) + y1 = f_x2 + x1_down + g_y1 = self.g.call(y1, training=training) # self.g(y1) gives pylint error + y2 = g_y1 + x2_down + if not concat: # Concat option needed for correct backward grads + return y1, y2 + return tf.concat([y1, y2], axis=self.axis) + + def backward(self, y, training=True): + """Reconstruct inputs from outputs; only valid when stride 1.""" + + assert self.strides == (1, 1) + + y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis) + g_y1 = self.g.call(y1, training=training) + x2 = y2 - g_y1 + f_x2 = self.f.call(x2, training=training) + x1 = y1 - f_x2 + + return tf.concat([x1, x2], axis=self.axis) + + def backward_grads_and_vars(self, x, dy, training=True): + """Manually compute backward gradients given input and output grads.""" + + with tf.GradientTape(persistent=True) as tape: + x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed + x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis) + tape.watch([x1, x2]) + # Stitch back x for `call` so tape records correct grads + x = tf.concat([x1, x2], axis=self.axis) + dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis) + y1, y2 = self.call(x, training=training, concat=False) + x2_down = ops.downsample( + x2, self.filters // 2, self.strides, axis=self.axis) + + grads_combined = tape.gradient( + y2, [y1] + self.g.trainable_variables, output_gradients=[dy2]) + dy2_y1, dg = grads_combined[0], grads_combined[1:] + dy1_plus = dy2_y1 + dy1 + + grads_combined = tape.gradient( + y1, [x1, x2] + self.f.trainable_variables, output_gradients=[dy1_plus]) + dx1, dx2, df = grads_combined[0], grads_combined[1], grads_combined[2:] + dx2 += tape.gradient(x2_down, [x2], output_gradients=[dy2])[0] + + del tape + + grads = df + dg + vars_ = self.f.trainable_variables + self.g.trainable_variables + + return tf.concat([dx1, dx2], axis=self.axis), grads, vars_ + + +def _BottleneckResidualInner(filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True): + """Single bottleneck residual inner function contained in _Resdual. + + Corresponds to the `F`/`G` functions in the paper. + Suitable for training on ImageNet dataset. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + + Returns: + A keras model + """ + + axis = 1 if data_format == "channels_first" else 3 + model = tf.keras.Sequential() + if batch_norm_first: + model.add( + tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=1, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME")) + + model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters // 4, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME")) + + model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=1, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME")) + + return model + + +def _ResidualInner(filters, + strides, + input_shape, + batch_norm_first=True, + data_format="channels_first", + fused=True): + """Single residual inner function contained in _ResdualBlock. + + Corresponds to the `F`/`G` functions in the paper. + + Args: + filters: output filter size + strides: length 2 list/tuple of integers for height and width strides + input_shape: length 3 list/tuple of integers + batch_norm_first: whether to apply activation and batch norm before conv + data_format: tensor data format, "NCHW"/"NHWC" + fused: use fused batch normalization if True + + Returns: + A keras model + """ + + axis = 1 if data_format == "channels_first" else 3 + model = tf.keras.Sequential() + if batch_norm_first: + model.add( + tf.keras.layers.BatchNormalization( + axis=axis, input_shape=input_shape, fused=fused)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=strides, + input_shape=input_shape, + data_format=data_format, + use_bias=False, + padding="SAME")) + + model.add(tf.keras.layers.BatchNormalization(axis=axis, fused=fused)) + model.add(tf.keras.layers.Activation("relu")) + model.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=(1, 1), + data_format=data_format, + use_bias=False, + padding="SAME")) + + return model diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f4436fd92506d54f1206fbfd424b897f9835657d --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py @@ -0,0 +1,346 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for basic building blocks used in eager mode RevNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import blocks + + +def _validate_block_call_channels_last(block_factory, test): + """Generic testing function for `channels_last` data format. + + Completes a set of tests varying data format, stride, and batch normalization + configured train vs test time. + Args: + block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock, + blocks._ResidualInner + test: tf.test.TestCase object + """ + with tf.device("/cpu:0"): # NHWC format + input_shape = (224, 224, 32) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + block = block_factory( + filters=64, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 224, 224, 64)) + test.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = block_factory( + filters=64, + strides=(2, 2), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 112, 112, 64)) + test.assertNotAllClose(y_tr, y_ev) + + +def _validate_block_call_channels_first(block_factory, test): + """Generic testing function for `channels_first` data format. + + Completes a set of tests varying data format, stride, and batch normalization + configured train vs test time. + Args: + block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock, + blocks._ResidualInner + test: tf.test.TestCase object + """ + if not tf.test.is_gpu_available(): + test.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (32, 224, 224) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride of 1 + block = block_factory(filters=64, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 64, 224, 224)) + test.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = block_factory(filters=64, strides=(2, 2), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + test.assertEqual(y_tr.shape, y_ev.shape) + test.assertEqual(y_ev.shape, (16, 64, 112, 112)) + test.assertNotAllClose(y_tr, y_ev) + + +class RevBlockTest(tf.test.TestCase): + + def test_call_channels_first(self): + """Test `call` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (32, 224, 224) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride of 1 + block = blocks.RevBlock( + n_res=3, filters=64, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, (16, 64, 224, 224)) + self.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = blocks.RevBlock( + n_res=3, filters=64, strides=(2, 2), input_shape=input_shape) + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, [16, 64, 112, 112]) + self.assertNotAllClose(y_tr, y_ev) + + def test_call_channels_last(self): + """Test `call` function with `channels_last` data format.""" + with tf.device("/cpu:0"): # NHWC format + input_shape = (224, 224, 32) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + block = blocks.RevBlock( + n_res=3, + filters=64, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, (16, 224, 224, 64)) + self.assertNotAllClose(y_tr, y_ev) + + # Stride of 2 + block = blocks.RevBlock( + n_res=3, + filters=64, + strides=(2, 2), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = block(x, training=True), block(x, training=False) + self.assertEqual(y_tr.shape, y_ev.shape) + self.assertEqual(y_ev.shape, (16, 112, 112, 64)) + self.assertNotAllClose(y_tr, y_ev) + + def test_backward_grads_and_vars_channels_first(self): + """Test `backward` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (32, 224, 224) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + y = tf.random_normal(shape=data_shape) + dy = tf.random_normal(shape=data_shape) + block = blocks.RevBlock( + n_res=3, filters=32, strides=(1, 1), input_shape=input_shape) + dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) + self.assertEqual(dy.shape, x.shape) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + + # Stride 2 + y = tf.random_normal(shape=(16, 32, 112, 112)) + dy = tf.random_normal(shape=(16, 32, 112, 112)) + block = blocks.RevBlock( + n_res=3, filters=32, strides=(2, 2), input_shape=input_shape) + dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) + self.assertEqual(dy.shape, x.shape) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + + def test_backward_grads_and_vars_channels_last(self): + """Test `backward` function with `channels_last` data format.""" + with tf.device("/cpu:0"): # NHWC format + input_shape = (224, 224, 32) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + + # Stride 1 + y = tf.random_normal(shape=data_shape) + dy = tf.random_normal(shape=data_shape) + block = blocks.RevBlock( + n_res=3, + filters=32, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) + self.assertEqual(dy.shape, x.shape) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + + # Stride 2 + y = tf.random_normal(shape=(16, 112, 112, 32)) + dy = tf.random_normal(shape=(16, 112, 112, 32)) + block = blocks.RevBlock( + n_res=3, + filters=32, + strides=(2, 2), + input_shape=input_shape, + data_format="channels_last") + dy, grads, vars_ = block.backward_grads_and_vars(x, y, dy) + self.assertEqual(dy.shape, x.shape) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + + +class _ResidualTest(tf.test.TestCase): + + def test_call(self): + """Test `call` function. + + Varying downsampling and data format options. + """ + + _validate_block_call_channels_first(blocks._Residual, self) + _validate_block_call_channels_last(blocks._Residual, self) + + def test_backward_channels_first(self): + """Test `backward` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (16, 224, 224) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + residual = blocks._Residual( + filters=16, strides=(1, 1), input_shape=input_shape) + y_tr, y_ev = residual(x, training=True), residual(x, training=False) + x_ = residual.backward(y_tr, training=True) + # The numerical loss is alarming; reconstructed inputs could differ from + # the original inputs often by more than 1e-3 + self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + x_ = residual.backward(y_ev, training=False) + self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + + def test_backward_channels_last(self): + """Test `backward` function with `channels_last` data format.""" + with tf.device("/cpu:0"): # NHWC format + input_shape = (224, 224, 16) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + residual = blocks._Residual( + filters=16, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + y_tr, y_ev = residual(x, training=True), residual(x, training=False) + x_ = residual.backward(y_tr, training=True) + # Egregious numerical error + self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + x_ = residual.backward(y_ev, training=False) + self.assertAllClose(x, x_, rtol=1e-01, atol=1e-01) + + def test_backward_grads_and_vars_channels_first(self): + """Test `backward_grads` function with `channels_first` data format.""" + if not tf.test.is_gpu_available(): + self.skipTest("GPU not available") + + with tf.device("/gpu:0"): # Default NCHW format + input_shape = (16, 224, 224) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + dy = tf.random_normal(shape=data_shape) + residual = blocks._Residual( + filters=16, strides=(1, 1), input_shape=input_shape) + dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars( + x, dy=dy, training=True) + dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars( + x, dy=dy, training=False) + self.assertNotAllClose(dx_tr, dx_ev) + self.assertTrue(isinstance(grads_tr, list)) + self.assertTrue(isinstance(grads_ev, list)) + self.assertTrue(isinstance(vars_tr, list)) + self.assertTrue(isinstance(vars_ev, list)) + for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev, + vars_ev): + if grad_tr is not None: # Batch norm moving mean, var gives None grad + self.assertEqual(grad_tr.shape, grad_ev.shape) + self.assertEqual(var_tr.shape, var_ev.shape) + self.assertEqual(grad_tr.shape, var_tr.shape) + + def test_backward_grads_and_vars_channels_last(self): + """Test `backward_grads` function with `channels_last` data format.""" + with tf.device("/cpu:0"): # NHWC format + input_shape = (224, 224, 16) + data_shape = (16,) + input_shape + x = tf.random_normal(shape=data_shape) + dy = tf.random_normal(shape=data_shape) + residual = blocks._Residual( + filters=16, + strides=(1, 1), + input_shape=input_shape, + data_format="channels_last") + dx_tr, grads_tr, vars_tr = residual.backward_grads_and_vars( + x, dy=dy, training=True) + dx_ev, grads_ev, vars_ev = residual.backward_grads_and_vars( + x, dy=dy, training=False) + self.assertNotAllClose(dx_tr, dx_ev) + self.assertTrue(isinstance(grads_tr, list)) + self.assertTrue(isinstance(grads_ev, list)) + self.assertTrue(isinstance(vars_tr, list)) + self.assertTrue(isinstance(vars_ev, list)) + for grad_tr, var_tr, grad_ev, var_ev in zip(grads_tr, vars_tr, grads_ev, + vars_ev): + if grad_tr is not None: # Batch norm moving mean, var gives None grad + self.assertEqual(grad_tr.shape, grad_ev.shape) + self.assertEqual(var_tr.shape, var_ev.shape) + self.assertEqual(grad_tr.shape, var_tr.shape) + + +class _ResidualInnerTest(tf.test.TestCase): + + def test_call(self): + """Test `call` function.""" + + _validate_block_call_channels_first(blocks._ResidualInner, self) + _validate_block_call_channels_last(blocks._ResidualInner, self) + + +class _BottleneckResidualInner(tf.test.TestCase): + + def test_call(self): + """Test `call` function.""" + + _validate_block_call_channels_first(blocks._BottleneckResidualInner, self) + _validate_block_call_channels_last(blocks._BottleneckResidualInner, self) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py new file mode 100644 index 0000000000000000000000000000000000000000..3bc69da5adae29e6b6f43ef5045eb0256e680fa4 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py @@ -0,0 +1,105 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script for reading and loading CIFAR-10.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf + +# Global constants describing the CIFAR data set. +IMAGE_HEIGHT = 32 +IMAGE_WIDTH = 32 +NUM_CHANNEL = 3 +NUM_TRAIN_IMG = 50000 +NUM_TEST_IMG = 10000 + + +def get_ds_from_tfrecords(data_dir, + split, + data_aug=True, + batch_size=100, + epochs=None, + shuffle=True, + data_format="channels_first", + num_parallel_calls=4, + prefetch=True, + div255=True, + dtype=tf.float32): + """Returns a tf.train.Dataset object from reading tfrecords. + + Args: + data_dir: Directory of tfrecords + split: "train", "validation", or "test" + data_aug: Apply data augmentation if True + batch_size: Batch size of dataset object + epochs: Number of epochs to repeat the dataset + shuffle: Shuffle the dataset if True + data_format: `channels_first` or `channels_last` + num_parallel_calls: Number of threads for dataset preprocess + prefetch: Apply prefetch for the dataset if True + div255: Divide the images by 255 if True + dtype: Data type of images + Returns: + A tf.train.Dataset object + + Raises: + ValueError: Unknown split + """ + + if split not in ["train", "validation", "test"]: + raise ValueError("Unknown split {}".format(split)) + + def _parser(serialized_example): + """Parses a single tf.Example into image and label tensors.""" + features = tf.parse_single_example( + serialized_example, + features={ + "image": tf.FixedLenFeature([], tf.string), + "label": tf.FixedLenFeature([], tf.int64), + }) + image = tf.decode_raw(features["image"], tf.uint8) + image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNEL]) + image = tf.cast(image, dtype) + label = tf.cast(features["label"], tf.int32) + + if data_aug: + image = tf.image.resize_image_with_crop_or_pad(image, IMAGE_HEIGHT + 4, + IMAGE_WIDTH + 4) + image = tf.random_crop(image, [IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNEL]) + image = tf.image.random_flip_left_right(image) + + if data_format == "channels_first": + image = tf.transpose(image, [2, 0, 1]) + + if div255: + image /= 255. + + return image, label + + filename = os.path.join(data_dir, split + ".tfrecords") + dataset = tf.data.TFRecordDataset(filename).repeat(epochs) + dataset = dataset.map(_parser, num_parallel_calls=num_parallel_calls) + + if prefetch: + dataset = dataset.prefetch(batch_size) + if shuffle: + dataset = dataset.shuffle(NUM_TRAIN_IMG) + dataset = dataset.batch(batch_size) + + return dataset diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py new file mode 100644 index 0000000000000000000000000000000000000000..f79428b2a97f0ac2ce991f4c26b9123cddc24325 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Read CIFAR-10 data from pickled numpy arrays and writes TFRecords. + +Generates tf.train.Example protos and writes them to TFRecord files from the +python version of the CIFAR-10 dataset downloaded from +https://www.cs.toronto.edu/~kriz/cifar.html. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import tarfile + +from absl import flags +from six.moves import cPickle as pickle +from six.moves import urllib +import tensorflow as tf + +CIFAR_FILENAME = 'cifar-10-python.tar.gz' +CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME +CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py' + + +def download_and_extract(data_dir): + """Download CIFAR-10 if not already downloaded.""" + filepath = os.path.join(data_dir, CIFAR_FILENAME) + if tf.gfile.Exists(filepath): + return filepath + if not tf.gfile.Exists(data_dir): + tf.gfile.MakeDirs(data_dir) + + urllib.request.urlretrieve(CIFAR_DOWNLOAD_URL, filepath) + tarfile.open(os.path.join(filepath), 'r:gz').extractall(data_dir) + return filepath + + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _get_file_names(): + """Returns the file names expected to exist in the input_dir.""" + file_names = {} + file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)] + file_names['validation'] = ['data_batch_5'] + file_names['test'] = ['test_batch'] + return file_names + + +def read_pickle_from_file(filename): + with tf.gfile.Open(filename, 'rb') as f: + if sys.version_info >= (3, 0): + data_dict = pickle.load(f, encoding='bytes') + else: + data_dict = pickle.load(f) + return data_dict + + +def convert_to_tfrecord(input_files, output_file): + """Converts files with pickled data to TFRecords.""" + print('Generating %s' % output_file) + with tf.python_io.TFRecordWriter(output_file) as record_writer: + for input_file in input_files: + data_dict = read_pickle_from_file(input_file) + data = data_dict[b'data'] + labels = data_dict[b'labels'] + num_entries_in_batch = len(labels) + + for i in range(num_entries_in_batch): + example = tf.train.Example( + features=tf.train.Features( + feature={ + 'image': _bytes_feature(data[i].tobytes()), + 'label': _int64_feature(labels[i]) + })) + record_writer.write(example.SerializeToString()) + + +def main(_): + print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL)) + download_and_extract(FLAGS.data_dir) + file_names = _get_file_names() + input_dir = os.path.join(FLAGS.data_dir, CIFAR_LOCAL_FOLDER) + + for mode, files in file_names.items(): + input_files = [os.path.join(input_dir, f) for f in files] + output_file = os.path.join(FLAGS.data_dir, mode + '.tfrecords') + try: + os.remove(output_file) + except OSError: + pass + convert_to_tfrecord(input_files, output_file) + print('Done!') + + +if __name__ == '__main__': + FLAGS = flags.FLAGS + flags.DEFINE_string( + 'data_dir', + default=None, + help='Directory to download and extract CIFAR-10 to.') + + tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py new file mode 100644 index 0000000000000000000000000000000000000000..263a65dc768f421ef39091af6a95033c3d83ac2b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/config.py @@ -0,0 +1,121 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Configuration in format of tf.contrib.training.HParams. +Supports CIFAR-10, CIFAR-100, and ImageNet datasets. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +tfe = tf.contrib.eager + + +def get_hparams_cifar_38(): + """RevNet-38 configurations for CIFAR-10/CIFAR-100.""" + + config = tf.contrib.training.HParams() + config.add_hparam("init_filters", 32) + config.add_hparam("init_kernel", 3) + config.add_hparam("init_stride", 1) + config.add_hparam("n_classes", 10) + config.add_hparam("n_rev_blocks", 3) + config.add_hparam("n_res", [3, 3, 3]) + config.add_hparam("filters", [32, 64, 112]) + config.add_hparam("strides", [1, 2, 2]) + config.add_hparam("batch_size", 100) + config.add_hparam("bottleneck", False) + config.add_hparam("fused", True) + config.add_hparam("init_max_pool", False) + if tfe.num_gpus() > 0: + config.add_hparam("input_shape", (3, 32, 32)) + config.add_hparam("data_format", "channels_first") + else: + config.add_hparam("input_shape", (32, 32, 3)) + config.add_hparam("data_format", "channels_last") + + # Training details + config.add_hparam("weight_decay", 2e-4) + config.add_hparam("momentum", .9) + config.add_hparam("lr_decay_steps", [40000, 60000]) + config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3]) + config.add_hparam("max_train_iter", 80000) + config.add_hparam("seed", 1234) + config.add_hparam("shuffle", True) + config.add_hparam("prefetch", True) + config.add_hparam("log_every", 50) + config.add_hparam("save_every", 50) + config.add_hparam("dtype", tf.float32) + config.add_hparam("eval_batch_size", 500) + config.add_hparam("div255", True) + config.add_hparam("iters_per_epoch", 50000 // config.batch_size) + config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) + + return config + + +def get_hparams_imagenet_56(): + """RevNet-56 configurations for ImageNet.""" + + config = tf.contrib.training.HParams() + config.add_hparam("init_filters", 128) + config.add_hparam("init_kernel", 7) + config.add_hparam("init_stride", 2) + config.add_hparam("n_classes", 1000) + config.add_hparam("n_rev_blocks", 4) + config.add_hparam("n_res", [2, 2, 2, 2]) + config.add_hparam("filters", [128, 256, 512, 832]) + config.add_hparam("strides", [1, 2, 2, 2]) + config.add_hparam("batch_size", 16) + config.add_hparam("bottleneck", True) + config.add_hparam("fused", True) + config.add_hparam("init_max_pool", True) + if tf.test.is_gpu_available(): + config.add_hparam("input_shape", (3, 224, 224)) + config.add_hparam("data_format", "channels_first") + else: + config.add_hparam("input_shape", (224, 224, 3)) + config.add_hparam("data_format", "channels_last") + + # Training details + config.add_hparam("weight_decay", 1e-4) + config.add_hparam("momentum", .9) + config.add_hparam("lr_decay_steps", [160000, 320000, 480000]) + config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3, 1e-4]) + config.add_hparam("max_train_iter", 600000) + config.add_hparam("seed", 1234) + config.add_hparam("shuffle", True) + config.add_hparam("prefetch", True) + config.add_hparam("log_every", 50) + config.add_hparam("save_every", 50) + config.add_hparam("dtype", tf.float32) + config.add_hparam("eval_batch_size", 500) + config.add_hparam("div255", True) + # TODO(lxuechen): Update this according to ImageNet data + config.add_hparam("iters_per_epoch", 50000 // config.batch_size) + config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch) + + if config.bottleneck: + filters = [f * 4 for f in config.filters] + config.filters = filters + + return config diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py new file mode 100644 index 0000000000000000000000000000000000000000..9ef11f8e9b470f3bae6b7cfec194774160fc2bd1 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/main.py @@ -0,0 +1,147 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Eager execution workflow with RevNet train on CIFAR-10.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl import flags +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import cifar_input +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import revnet +tfe = tf.contrib.eager + + +def main(_): + """Eager execution workflow with RevNet trained on CIFAR-10.""" + if FLAGS.data_dir is None: + raise ValueError("No supplied data directory") + + if not os.path.exists(FLAGS.data_dir): + raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir)) + + tf.enable_eager_execution() + config = config_.get_hparams_cifar_38() + model = revnet.RevNet(config=config) + + ds_train = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="train", + data_aug=True, + batch_size=config.batch_size, + epochs=config.epochs, + shuffle=config.shuffle, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.prefetch) + + ds_validation = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="validation", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.prefetch) + + ds_test = cifar_input.get_ds_from_tfrecords( + data_dir=FLAGS.data_dir, + split="test", + data_aug=False, + batch_size=config.eval_batch_size, + epochs=1, + data_format=config.data_format, + dtype=config.dtype, + prefetch=config.prefetch) + + global_step = tfe.Variable(1, trainable=False) + + def learning_rate(): # TODO(lxuechen): Remove once cl/201089859 is in place + return tf.train.piecewise_constant(global_step, config.lr_decay_steps, + config.lr_list) + + optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9) + checkpoint = tf.train.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=global_step) + + if FLAGS.train_dir: + summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir) + if FLAGS.restore: + latest_path = tf.train.latest_checkpoint(FLAGS.train_dir) + checkpoint.restore(latest_path) + + for x, y in ds_train: + loss = train_one_iter(model, x, y, optimizer, global_step=global_step) + + if global_step % config.log_every == 0: + it_validation = ds_validation.make_one_shot_iterator() + it_test = ds_test.make_one_shot_iterator() + acc_validation = evaluate(model, it_validation) + acc_test = evaluate(model, it_test) + print("Iter {}, " + "train loss {}, " + "validation accuracy {}, " + "test accuracy {}".format(global_step.numpy(), loss, acc_validation, + acc_test)) + + if FLAGS.train_dir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("Validation accuracy", acc_validation) + tf.contrib.summary.scalar("Test accuracy", acc_test) + tf.contrib.summary.scalar("Training loss", loss) + + if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir: + checkpoint.save(file_prefix=FLAGS.train_dir + "ckpt") + + +def train_one_iter(model, inputs, labels, optimizer, global_step=None): + """Train for one iteration.""" + grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + + return loss.numpy() + + +def evaluate(model, iterator): + """Compute accuracy with the given dataset iterator.""" + accuracy = tfe.metrics.Accuracy() + for x, y in iterator: + logits, _ = model(x, training=False) + accuracy( + labels=tf.cast(y, tf.int64), + predictions=tf.argmax(logits, axis=1, output_type=tf.int64)) + + return accuracy.result().numpy() + + +if __name__ == "__main__": + flags.DEFINE_string( + "train_dir", + default=None, + help="[Optional] Directory to store the training information") + flags.DEFINE_string( + "data_dir", default=None, help="Directory to load tfrecords.") + flags.DEFINE_boolean( + "restore", + default=True, + help="[Optional] Restore the latest checkpoint from `train_dir` if True") + FLAGS = flags.FLAGS + tf.app.run(main) diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops.py b/tensorflow/contrib/eager/python/examples/revnet/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed5d363e6c8bffd817357c006abee7ac0d1dbba --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/ops.py @@ -0,0 +1,70 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Customized basic operations. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def downsample(x, filters, strides, axis=1): + """Downsample feature map with avg pooling, if filter size doesn't match.""" + + def pad_strides(strides, axis=1): + """Convert length 2 to length 4 strides. + + Needed since `tf.layers.Conv2D` uses length 2 strides, whereas operations + such as `tf.nn.avg_pool` use length 4 strides. + + Args: + strides: length 2 list/tuple strides for height and width + axis: integer specifying feature dimension according to data format + Returns: + length 4 strides padded with 1 on batch and channel dimension + """ + + assert len(strides) == 2 + + if axis == 1: + return [1, 1, strides[0], strides[1]] + return [1, strides[0], strides[1], 1] + + assert len(x.shape) == 4 and (axis == 1 or axis == 3) + + data_format = "NCHW" if axis == 1 else "NHWC" + strides_ = pad_strides(strides, axis=axis) + + if strides[0] > 1: + x = tf.nn.avg_pool( + x, strides_, strides_, padding="VALID", data_format=data_format) + + in_filter = x.shape[axis] + out_filter = filters + + if in_filter < out_filter: + pad_size = [(out_filter - in_filter) // 2, (out_filter - in_filter) // 2] + if axis == 1: + x = tf.pad(x, [[0, 0], pad_size, [0, 0], [0, 0]]) + else: + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], pad_size]) + # In case `tape.gradient(x, [x])` produces a list of `None` + return x + 0. diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops_test.py b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc2641faf5a5d26262de683e52e36b1f42b3a7b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for basic ops used in eager mode RevNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import ops +tfe = tf.contrib.eager + + +class OpsTest(tf.test.TestCase): + + def test_downsample(self): + """Test `possible_down_sample` function with mock object.""" + + batch_size = 100 + # NHWC format + x = tf.random_normal(shape=[batch_size, 32, 32, 3]) + # HW doesn't change but number of features increased + y = ops.downsample(x, filters=5, strides=(1, 1), axis=3) + self.assertEqual(y.shape, [batch_size, 32, 32, 5]) + # Feature map doesn't change but HW reduced + y = ops.downsample(x, filters=3, strides=(2, 2), axis=3) + self.assertEqual(y.shape, [batch_size, 16, 16, 3]) + # Number of feature increased and HW reduced + y = ops.downsample(x, filters=5, strides=(2, 2), axis=3) + self.assertEqual(y.shape, [batch_size, 16, 16, 5]) + + # Test gradient flow + x = tf.random_normal(shape=[batch_size, 32, 32, 3]) + with tfe.GradientTape() as tape: + tape.watch(x) + y = ops.downsample(x, filters=3, strides=(1, 1)) + self.assertEqual(y.shape, x.shape) + dy = tf.random_normal(shape=[batch_size, 3, 32, 32]) + grad, = tape.gradient(y, [x], output_gradients=[dy]) + self.assertEqual(grad.shape, x.shape) + + # Default NCHW format + if tf.test.is_gpu_available(): + x = tf.random_normal(shape=[batch_size, 3, 32, 32]) + # HW doesn't change but feature map reduced + y = ops.downsample(x, filters=5, strides=(1, 1)) + self.assertEqual(y.shape, [batch_size, 5, 32, 32]) + # Feature map doesn't change but HW reduced + y = ops.downsample(x, filters=3, strides=(2, 2)) + self.assertEqual(y.shape, [batch_size, 3, 16, 16]) + # Both feature map and HW reduced + y = ops.downsample(x, filters=5, strides=(2, 2)) + self.assertEqual(y.shape, [batch_size, 5, 16, 16]) + + # Test gradient flow + x = tf.random_normal(shape=[batch_size, 3, 32, 32]) + with tfe.GradientTape() as tape: + tape.watch(x) + y = ops.downsample(x, filters=3, strides=(1, 1)) + self.assertEqual(y.shape, x.shape) + dy = tf.random_normal(shape=[batch_size, 3, 32, 32]) + grad, = tape.gradient(y, [x], output_gradients=[dy]) + self.assertEqual(grad.shape, x.shape) + + +if __name__ == '__main__': + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py new file mode 100644 index 0000000000000000000000000000000000000000..b3b8c262b1517baa1e65c105db9882b6f7672439 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -0,0 +1,241 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible residual network compatible with eager execution. + +Code for main model. + +Reference [The Reversible Residual Network: Backpropagation +Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import operator + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import blocks + + +class RevNet(tf.keras.Model): + """RevNet that depends on all the blocks.""" + + def __init__(self, config): + """Initialize RevNet with building blocks. + + Args: + config: tf.contrib.training.HParams object; specifies hyperparameters + """ + super(RevNet, self).__init__() + self.axis = 1 if config.data_format == "channels_first" else 3 + self.config = config + + self._init_block = self._construct_init_block() + self._block_list = self._construct_intermediate_blocks() + self._final_block = self._construct_final_block() + + def _construct_init_block(self): + init_block = tf.keras.Sequential( + [ + tf.keras.layers.Conv2D( + filters=self.config.init_filters, + kernel_size=self.config.init_kernel, + strides=(self.config.init_stride, self.config.init_stride), + data_format=self.config.data_format, + use_bias=False, + padding="SAME", + input_shape=self.config.input_shape), + tf.keras.layers.BatchNormalization( + axis=self.axis, fused=self.config.fused), + tf.keras.layers.Activation("relu"), + ], + name="init") + if self.config.init_max_pool: + init_block.add( + tf.keras.layers.MaxPooling2D( + pool_size=(3, 3), + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format)) + return init_block + + def _construct_final_block(self): + f = self.config.filters[-1] # Number of filters + r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio + r *= self.config.init_stride + if self.config.init_max_pool: + r *= 2 + + if self.config.data_format == "channels_first": + w, h = self.config.input_shape[1], self.config.input_shape[2] + input_shape = (f, w // r, h // r) + elif self.config.data_format == "channels_last": + w, h = self.config.input_shape[0], self.config.input_shape[1] + input_shape = (w // r, h // r, f) + else: + raise ValueError("Data format should be either `channels_first`" + " or `channels_last`") + + final_block = tf.keras.Sequential( + [ + tf.keras.layers.BatchNormalization( + axis=self.axis, + input_shape=input_shape, + fused=self.config.fused), + tf.keras.layers.Activation("relu"), + tf.keras.layers.GlobalAveragePooling2D( + data_format=self.config.data_format), + tf.keras.layers.Dense(self.config.n_classes) + ], + name="final") + return final_block + + def _construct_intermediate_blocks(self): + # Precompute input shape after initial block + stride = self.config.init_stride + if self.config.init_max_pool: + stride *= 2 + if self.config.data_format == "channels_first": + w, h = self.config.input_shape[1], self.config.input_shape[2] + input_shape = (self.config.init_filters, w // stride, h // stride) + else: + w, h = self.config.input_shape[0], self.config.input_shape[1] + input_shape = (w // stride, h // stride, self.config.init_filters) + + # Aggregate intermediate blocks + block_list = tf.contrib.checkpoint.List() + for i in range(self.config.n_rev_blocks): + # RevBlock configurations + n_res = self.config.n_res[i] + filters = self.config.filters[i] + if filters % 2 != 0: + raise ValueError("Number of output filters must be even to ensure" + "correct partitioning of channels") + stride = self.config.strides[i] + strides = (self.config.strides[i], self.config.strides[i]) + + # Add block + rev_block = blocks.RevBlock( + n_res, + filters, + strides, + input_shape, + batch_norm_first=(i != 0), # Only skip on first block + data_format=self.config.data_format, + bottleneck=self.config.bottleneck, + fused=self.config.fused) + block_list.append(rev_block) + + # Precompute input shape for the next block + if self.config.data_format == "channels_first": + w, h = input_shape[1], input_shape[2] + input_shape = (filters, w // stride, h // stride) + else: + w, h = input_shape[0], input_shape[1] + input_shape = (w // stride, h // stride, filters) + + return block_list + + def call(self, inputs, training=True): + """Forward pass.""" + + # Only store hidden states during training + if training: + saved_hidden = [inputs] + + h = self._init_block(inputs, training=training) + if training: + saved_hidden.append(h) + + for block in self._block_list: + h = block(h, training=training) + if training: + saved_hidden.append(h) + + logits = self._final_block(h, training=training) + + return (logits, saved_hidden) if training else (logits, None) + + def compute_loss(self, logits, labels): + """Compute cross entropy loss.""" + + cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=labels) + + return tf.reduce_mean(cross_ent) + + def compute_gradients(self, inputs, labels, training=True): + """Manually computes gradients. + + Args: + inputs: Image tensor, either NHWC or NCHW, conforming to `data_format` + labels: One-hot labels for classification + training: for batch normalization + + Returns: + list of tuple each being (grad, var) for optimizer use + """ + + # Forward pass record hidden states before downsampling + _, saved_hidden = self.call(inputs, training=training) + + grads_all = [] + vars_all = [] + + # Manually backprop through last block + x = saved_hidden[-1] + with tf.GradientTape() as tape: + x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed + tape.watch(x) + logits = self._final_block(x, training=training) + loss = self.compute_loss(logits, labels) + + grads_combined = tape.gradient(loss, + [x] + self._final_block.trainable_variables) + dy, grads_ = grads_combined[0], grads_combined[1:] + grads_all += grads_ + vars_all += self._final_block.trainable_variables + + # Manually backprop through intermediate blocks + for block in reversed(self._block_list): + y = saved_hidden.pop() + x = saved_hidden[-1] + dy, grads, vars_ = block.backward_grads_and_vars( + x, y, dy, training=training) + grads_all += grads + vars_all += vars_ + + # Manually backprop through first block + saved_hidden.pop() + x = saved_hidden.pop() + assert not saved_hidden # Cleared after backprop + + with tf.GradientTape() as tape: + x = tf.identity(x) # TODO(lxuechen): Remove after b/110264016 is fixed + y = self._init_block(x, training=training) + + grads_all += tape.gradient( + y, self._init_block.trainable_variables, output_gradients=[dy]) + vars_all += self._init_block.trainable_variables + + grads_all = self._apply_weight_decay(grads_all, vars_all) + + return grads_all, vars_all, loss + + def _apply_weight_decay(self, grads, vars_): + """Update gradients to reflect weight decay.""" + return [g + self.config.weight_decay * v for g, v in zip(grads, vars_)] diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c712e61858cf7314fe5aefacbdc4dbeb7f0d9fb4 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -0,0 +1,294 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for basic building blocks used in eager mode RevNet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc +import time + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.revnet import config as config_ +from tensorflow.contrib.eager.python.examples.revnet import revnet +from tensorflow.python.client import device_lib +tfe = tf.contrib.eager + + +def train_one_iter(model, inputs, labels, optimizer, global_step=None): + """Train for one iteration.""" + grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) + optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) + + return loss + + +class RevnetTest(tf.test.TestCase): + + def setUp(self): + super(RevnetTest, self).setUp() + config = config_.get_hparams_imagenet_56() + shape = (config.batch_size,) + config.input_shape + self.model = revnet.RevNet(config=config) + self.x = tf.random_normal(shape=shape) + self.t = tf.random_uniform( + shape=[config.batch_size], + minval=0, + maxval=config.n_classes, + dtype=tf.int32) + self.config = config + + def tearDown(self): + del self.model + del self.x + del self.t + del self.config + super(RevnetTest, self).tearDown() + + def test_call(self): + """Test `call` function.""" + + y, _ = self.model(self.x, training=False) + self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes]) + + def test_compute_gradients(self): + """Test `compute_gradients` function.""" + + grads, vars_, _ = self.model.compute_gradients(inputs=self.x, labels=self.t) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + self.assertEqual(len(grads), len(vars_)) + for grad, var in zip(grads, vars_): + if grad is not None: + self.assertEqual(grad.shape, var.shape) + + def test_call_defun(self): + """Test `call` function with defun.""" + + y, _ = tfe.defun(self.model.call)(self.x, training=False) + self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes]) + + def test_compute_gradients_defun(self): + """Test `compute_gradients` function with defun.""" + compute_gradients = tfe.defun(self.model.compute_gradients) + grads, vars_, _ = compute_gradients(self.x, self.t) + self.assertTrue(isinstance(grads, list)) + self.assertTrue(isinstance(vars_, list)) + self.assertEqual(len(grads), len(vars_)) + for grad, var in zip(grads, vars_): + if grad is not None: + self.assertEqual(grad.shape, var.shape) + + def test_training_graph(self): + """Test model training in graph mode.""" + + with tf.Graph().as_default(): + x = tf.random_normal( + shape=(self.config.batch_size,) + self.config.input_shape) + t = tf.random_uniform( + shape=(self.config.batch_size,), + minval=0, + maxval=self.config.n_classes, + dtype=tf.int32) + global_step = tfe.Variable(0., trainable=False) + model = revnet.RevNet(config=self.config) + grads_all, vars_all, _ = model.compute_gradients(x, t, training=True) + optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) + # TODO(lxuechen): This doesn't work due to b/110145168 + with tf.control_dependencies(model.updates): + train_op = optimizer.apply_gradients( + zip(grads_all, vars_all), global_step=global_step) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for _ in range(1): + sess.run(train_op) + + +# Benchmark related +def device_and_data_format(): + return ("/gpu:0", + "channels_first") if tf.test.is_gpu_available() else ("/cpu:0", + "channels_last") + + +def random_batch(batch_size, config): + shape = (batch_size,) + config.input_shape + images = tf.random_uniform(shape) + labels = tf.random_uniform( + [batch_size], minval=0, maxval=config.n_classes, dtype=tf.int32) + + return images, labels + + +class MockIterator(object): + + def __init__(self, tensors): + self._tensors = [tf.identity(x) for x in tensors] + + def next(self): + return self._tensors + + +class RevnetBenchmark(tf.test.Benchmark): + """Eager and graph benchmarks for RevNet.""" + + def _train_batch_sizes(self): + """Shamelessly copied from `resnet50_test.py`. + + Note: This is targeted towards ImageNet. CIFAR-10 should allow more + aggressive batch sizes. + + Returns: + A tuple of possible batch sizes + """ + for device in device_lib.list_local_devices(): + if tf.DeviceSpec.from_string(device.name).device_type == "GPU": + if "K20" in device.physical_device_desc: + return (16,) + if "P100" in device.physical_device_desc: + return (16, 32, 64) + if tf.DeviceSpec.from_string(device.name).device_type == "TPU": + return (32,) + return (16, 32) + + def _force_device_sync(self): + """Shamelessly copied from `resnet50_test.py`.""" + tf.constant(1.).cpu() + + def _report(self, label, start, num_iters, device, batch_size, data_format): + avg_time = (time.time() - start) / num_iters + dev = tf.DeviceSpec.from_string(device).device_type.lower() + name = "%s_%s_batch_%d_%s" % (label, dev, batch_size, data_format) + extras = {"examples_per_sec": batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def _benchmark_eager_apply(self, + label, + device_and_format, + defun=False, + execution_mode=None, + compiled=False): + config = config_.get_hparams_imagenet_56() + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + model = revnet.RevNet(config=config) + if defun: + model.call = tfe.defun(model.call, compiled=compiled) + batch_size = 64 + num_burn = 5 + num_iters = 10 + with tf.device(device): + images, _ = random_batch(batch_size, config) + for _ in range(num_burn): + model(images, training=False) + if execution_mode: + tfe.async_wait() + gc.collect() + start = time.time() + for _ in range(num_iters): + model(images, training=False) + if execution_mode: + tfe.async_wait() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_apply_sync(self): + self._benchmark_eager_apply( + "eager_apply_sync", device_and_data_format(), defun=False) + + def benchmark_eager_apply_async(self): + self._benchmark_eager_apply( + "eager_apply_async", + device_and_data_format(), + defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_call_defun(self): + self._benchmark_eager_apply( + "eager_apply_with_defun", device_and_data_format(), defun=True) + + def _benchmark_eager_train(self, + label, + make_iterator, + device_and_format, + defun=False, + execution_mode=None, + compiled=False): + config = config_.get_hparams_imagenet_56() + with tfe.execution_mode(execution_mode): + device, data_format = device_and_format + for batch_size in self._train_batch_sizes(): + (images, labels) = random_batch(batch_size, config) + model = revnet.RevNet(config=config) + optimizer = tf.train.GradientDescentOptimizer(0.1) + if defun: + model.call = tfe.defun(model.call) + + num_burn = 3 + num_iters = 10 + with tf.device(device): + iterator = make_iterator((images, labels)) + for _ in range(num_burn): + (images, labels) = iterator.next() + train_one_iter(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + gc.collect() + + start = time.time() + for _ in range(num_iters): + (images, labels) = iterator.next() + train_one_iter(model, images, labels, optimizer) + if execution_mode: + tfe.async_wait() + self._force_device_sync() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_train_sync(self): + self._benchmark_eager_train( + "eager_train_sync", MockIterator, device_and_data_format(), defun=False) + + def benchmark_eager_train_async(self): + self._benchmark_eager_train( + "eager_train_async", + MockIterator, + device_and_data_format(), + defun=False, + execution_mode=tfe.ASYNC) + + def benchmark_eager_train_defun(self): + self._benchmark_eager_train( + "eager_train", MockIterator, device_and_data_format(), defun=False) + + def benchmark_eager_train_datasets_with_defun(self): + + def make_iterator(tensors): + with tf.device("/device:CPU:0"): + ds = tf.data.Dataset.from_tensors(tensors).repeat() + return tfe.Iterator(ds) + + self._benchmark_eager_train( + "eager_train_dataset_with_defun", + make_iterator, + device_and_data_format(), + defun=True) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/sagan/BUILD b/tensorflow/contrib/eager/python/examples/sagan/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b470a41d815ce650731680065cc7341f844e3fdc --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/BUILD @@ -0,0 +1,59 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +# Model +py_library( + name = "config", + srcs = ["config.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "ops", + srcs = ["ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "sagan", + srcs = ["sagan.py"], + srcs_version = "PY2AND3", + deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +# Tests +cuda_py_test( + name = "ops_test", + size = "small", + srcs = ["ops_test.py"], + additional_deps = [ + ":ops", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "sagan_test", + size = "large", + srcs = ["sagan_test.py"], + additional_deps = [ + ":config", + ":sagan", + "//tensorflow:tensorflow_py", + ], + tags = [ + "optonly", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/sagan/config.py b/tensorflow/contrib/eager/python/examples/sagan/config.py new file mode 100644 index 0000000000000000000000000000000000000000..1967bbd867447d9deaf9a7cb3b22a38889276a50 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/config.py @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Self-attention generative adversarial with eager execution. + +Configuration in format of tf.contrib.training.HParams. +Supports default 128x128 ImageNet. + +Reference [Self-Attention Generative Adversarial +Networks](https://arxiv.org/pdf/1805.08318.pdf) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +tfe = tf.contrib.eager + + +def get_hparams_imagenet(): + """Configurations to train SAGAN on 128x128 ImageNet dataset.""" + config = tf.contrib.training.HParams() + if tf.test.is_gpu_available(): + config.add_hparam("image_shape", (3, 128, 128)) + config.add_hparam("data_format", "channels_first") + config.add_hparam("g_init_shape", (512, 4, 4)) + else: + config.add_hparam("image_shape", (128, 128, 3)) + config.add_hparam("data_format", "channels_first") + config.add_hparam("g_init_shape", (4, 4, 512)) + + config.add_hparam("latent_dim", 128) + config.add_hparam("update_g_once_every", 1) + config.add_hparam("batch_size", 64) + config.add_hparam("d_init_filters", 32) + config.add_hparam("num_upsamples", 5) + # (512, 4, 4) -> (3, 128, 128) + return config + + +def get_hparams_mock(): + """Configurations of smaller networks for testing.""" + config = tf.contrib.training.HParams() + if tf.test.is_gpu_available(): + config.add_hparam("image_shape", (3, 16, 16)) + config.add_hparam("data_format", "channels_first") + config.add_hparam("g_init_shape", (32, 2, 2)) + else: + config.add_hparam("image_shape", (16, 16, 3)) + config.add_hparam("data_format", "channels_last") + config.add_hparam("g_init_shape", (2, 2, 32)) + + config.add_hparam("latent_dim", 16) + config.add_hparam("update_g_once_every", 1) + config.add_hparam("batch_size", 2) + config.add_hparam("d_init_filters", 4) + config.add_hparam("num_upsamples", 3) + # (32, 2, 2) -> (3, 16, 16) + return config diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops.py b/tensorflow/contrib/eager/python/examples/sagan/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9a03cab1d12fc16baa7343f72ac58ccd39f698bc --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/ops.py @@ -0,0 +1,71 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Self-attention generative adversarial with eager execution. + +Auxiliary operations. + +Reference [Self-Attention Generative Adversarial +Networks](https://arxiv.org/pdf/1805.08318.pdf) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def flatten_hw(x, data_format="channels_first"): + """Flatten the input tensor across height and width dimensions.""" + if data_format == "channels_last": + x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first` + + old_shape = tf.shape(x) + new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]] + + return tf.reshape(x, new_shape) + + +def broaden_hw(x, h, w, c, data_format="channels_first"): + """Broaden dimension so that output has height and width.""" + if data_format == "channels_first": + shape = [-1, c, h, w] + else: + shape = [-1, h, w, c] + + return tf.reshape(x, shape) + + +class BroadenHW(tf.keras.layers.Layer): + """Wrapper class so that `broaden_hw` can be used in `tf.keras.Sequential`.""" + + def __init__(self, h, w, c, data_format="channels_first"): + super(BroadenHW, self).__init__() + self.h = h + self.w = w + self.c = c + self.data_format = data_format + + def call(self, x): + return broaden_hw( + x, h=self.h, w=self.w, c=self.c, data_format=self.data_format) + + def compute_output_shape(self, input_shape): + input_shape = tf.TensorShape(input_shape).as_list() + if self.data_format == "channels_first": + output_shape = (input_shape[0], self.c, self.h, self.w) + else: + output_shape = (input_shape[0], self.h, self.w, self.c) + + return tf.TensorShape(output_shape) diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3454985904215b59d27fc4b76ccb4a8c2c2eff00 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py @@ -0,0 +1,59 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for auxiliary operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.sagan import ops + + +class OpsTest(tf.test.TestCase): + + def test_flatten_hw(self): + """Test `flatten_hw` function with mock object.""" + + batch_size = 1 + # Default NCHW format + if tf.test.is_gpu_available(): + x = tf.random_normal(shape=(batch_size, 3, 4, 4)) + y = ops.flatten_hw(x, data_format="channels_first") + self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) + + # NHWC format + x = tf.random_normal(shape=(batch_size, 4, 4, 3)) + y = ops.flatten_hw(x, data_format="channels_last") + self.assertEqual(y.shape, (batch_size, 4 * 4, 3)) + + def test_broaden_hw(self): + """Test `broaden_hw` function with mock object.""" + + batch_size = 1 + # NHWC format + x = tf.random_normal(shape=[batch_size, 4 * 4 * 16]) + y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_last") + self.assertEqual(y.shape, (batch_size, 4, 4, 16)) + + # Default NCHW format + if tf.test.is_gpu_available(): + y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_first") + self.assertEqual(y.shape, (batch_size, 16, 4, 4)) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py new file mode 100644 index 0000000000000000000000000000000000000000..561be36c911d7145e2d4a5ed12eccd8ceb054f45 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/sagan.py @@ -0,0 +1,232 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Self-attention generative adversarial with eager execution. + +Code for main model. + +Reference [Self-Attention Generative Adversarial +Networks](https://arxiv.org/pdf/1805.08318.pdf) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.sagan import ops +tfe = tf.contrib.eager + + +class SelfAttentionModule(tf.keras.Model): + """Self-attention module composed of convolutional layers.""" + + def __init__(self, + attention_features, + original_features, + data_format="channels_first"): + """Initialize the module. + + Args: + attention_features: Number of filters for the attention computation. + original_features: Number of filters of the original Tensor. + data_format: Either 'channels_first' or 'channels_last' + """ + super(SelfAttentionModule, self).__init__() + self.data_format = data_format + # Matrix multiplication implemented as 2D Convolution + self.f = tf.keras.layers.Conv2D( + filters=attention_features, + kernel_size=1, + strides=(1, 1), + data_format=data_format) + self.g = tf.keras.layers.Conv2D( + filters=attention_features, + kernel_size=1, + strides=(1, 1), + data_format=data_format) + self.h = tf.keras.layers.Conv2D( + filters=original_features, + kernel_size=1, + strides=(1, 1), + data_format=data_format) + self.scale = tfe.Variable(0., trainable=True) + + def call(self, x): + f = self.f(x) + g = self.g(x) + h = self.h(x) + + f_flatten = ops.flatten_hw(f, data_format=self.data_format) + g_flatten = ops.flatten_hw(g, data_format=self.data_format) + h_flatten = ops.flatten_hw(h, data_format=self.data_format) + + s = tf.matmul(g_flatten, f_flatten, transpose_b=True) + b = tf.nn.softmax(s, axis=-1) + o = tf.matmul(b, h_flatten) + y = self.scale * tf.reshape(o, tf.shape(x)) + x + + return y + + def compute_output_shape(self, input_shape): + return input_shape + + +class SAGAN(tf.contrib.checkpoint.Checkpointable): + """Self-attention generative adversarial network.""" + + def __init__(self, config): + """Initialize the model. + + Args: + config: tf.contrib.training.HParams object; specifies hyperparameters + """ + super(SAGAN, self).__init__() + self.config = config + self.generator = self._construct_generator() + self.discriminator = self._construct_discriminator() + + def _construct_generator(self): + """Construct generator.""" + # TODO(lxuechen): Add spectral normalization for WGAN + axis = 1 if self.config.data_format == "channels_first" else 3 + + generator = tf.keras.Sequential() + generator.add( + tf.keras.layers.InputLayer(input_shape=(self.config.latent_dim,))) + generator.add( + tf.keras.layers.Dense( + units=np.prod(self.config.g_init_shape), activation=tf.nn.relu)) + + if self.config.data_format == "channels_first": + c, h, w = self.config.g_init_shape + else: + h, w, c = self.config.g_init_shape + + # Reshape to NHWC/NCHW + generator.add( + ops.BroadenHW(h=h, w=w, c=c, data_format=self.config.data_format)) + + filters_list = [c // 2**p for p in range(1, self.config.num_upsamples + 1)] + filters_list[-1] = 3 # Standard RGB images + + for filters in filters_list[:len(filters_list) // 2]: + generator.add( + tf.keras.layers.Conv2DTranspose( + filters=filters, + kernel_size=4, + strides=(2, 2), + use_bias=False, + padding="SAME", + data_format=self.config.data_format)) + generator.add(tf.keras.layers.BatchNormalization(axis=axis)) + generator.add(tf.keras.layers.Activation("relu")) + + # pylint: disable=undefined-loop-variable + generator.add( + SelfAttentionModule( + original_features=filters, + attention_features=filters // 8, + data_format=self.config.data_format)) + # pylint: enable=undefined-loop-variable + + for filters in filters_list[len(filters_list) // 2:]: + generator.add( + tf.keras.layers.Conv2DTranspose( + filters=filters, + kernel_size=4, + strides=(2, 2), + use_bias=False, + padding="SAME", + data_format=self.config.data_format)) + if filters == 3: + # Assume Image rescaled to [-1, 1] + generator.add(tf.keras.layers.Activation("tanh")) + else: + generator.add(tf.keras.layers.BatchNormalization(axis=axis)) + generator.add(tf.keras.layers.Activation("relu")) + + return generator + + def _construct_discriminator(self): + """Construct discriminator.""" + # TODO(lxuechen): Add spectral normalization for WGAN + discriminator = tf.keras.Sequential() + discriminator.add( + tf.keras.layers.InputLayer(input_shape=self.config.image_shape)) + + filters_list = [ + self.config.d_init_filters * 2**p + for p in range(self.config.num_upsamples) + ] + + for filters in filters_list[:(len(filters_list) + 1) // 2]: + discriminator.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=4, + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format)) + discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) + + # pylint: disable=undefined-loop-variable + discriminator.add( + SelfAttentionModule( + original_features=filters, + attention_features=filters // 8, + data_format=self.config.data_format)) + # pylint: enable=undefined-loop-variable + + for filters in filters_list[(len(filters_list) + 1) // 2:]: + discriminator.add( + tf.keras.layers.Conv2D( + filters=filters, + kernel_size=4, + strides=(2, 2), + padding="SAME", + data_format=self.config.data_format)) + discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1)) + + discriminator.add(tf.keras.layers.Flatten()) + discriminator.add(tf.keras.layers.Dense(units=1)) + + return discriminator + + def compute_loss_and_grads(self, real_images, noise, training=True): + """Compute loss and gradients for both generator and discriminator.""" + # TODO(lxuechen): Add gradient penalty for discriminator + with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape: + real_logits = self.discriminator(real_images, training=training) + + fake_images = self.generator.call(noise, training=training) + fake_logits = self.discriminator.call(fake_images) + + g_loss = self.compute_g_loss(fake_logits) + d_loss = self.compute_d_loss(fake_logits, real_logits) + + g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables) + d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables) + + return g_loss, d_loss, g_grads, d_grads + + def compute_g_loss(self, fake_logits): + return -tf.reduce_mean(fake_logits) # Hinge loss + + def compute_d_loss(self, fake_logits, real_logits): + # Hinge loss + real_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits)) + fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake_logits)) + return real_loss + fake_loss diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py new file mode 100644 index 0000000000000000000000000000000000000000..18345945108111b57c5401c26b7dca0bfc8f8316 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for self-attention generative adversarial network.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.sagan import config as config_ +from tensorflow.contrib.eager.python.examples.sagan import sagan +tfe = tf.contrib.eager + + +class SAGANTest(tf.test.TestCase): + + def setUp(self): + super(SAGANTest, self).setUp() + config = config_.get_hparams_mock() + self.noise_shape = (config.batch_size, config.latent_dim) + self.logits_shape = (config.batch_size, 1) + self.images_shape = (config.batch_size,) + config.image_shape + + self.model = sagan.SAGAN(config=config) + self.noise = tf.random_normal(shape=self.noise_shape) + self.real_images = tf.random_normal(shape=self.images_shape) + self.config = config + + def tearDown(self): + del self.model + del self.noise + del self.real_images + super(SAGANTest, self).tearDown() + + def test_generator_call(self): + """Test `generator.__call__` function.""" + fake_images = self.model.generator(self.noise, training=False) + self.assertEqual(fake_images.shape, self.images_shape) + + def test_generator_call_defun(self): + """Test `generator.__call__` function with defun.""" + call_ = tfe.defun(self.model.generator.__call__) + fake_images = call_(self.noise, training=False) + self.assertEqual(fake_images.shape, self.images_shape) + + def test_discriminator_call(self): + """Test `discriminator.__call__` function.""" + real_logits = self.model.discriminator(self.real_images) + self.assertEqual(real_logits.shape, self.logits_shape) + + def test_discriminator_call_defun(self): + """Test `discriminator.__call__` function with defun.""" + call_ = tfe.defun(self.model.discriminator.__call__) + real_logits = call_(self.real_images) + self.assertEqual(real_logits.shape, self.logits_shape) + + def test_compute_loss_and_grads(self): + """Test `compute_loss_and_grads` function.""" + g_loss, d_loss, g_grads, d_grads = self.model.compute_loss_and_grads( + self.real_images, self.noise, training=False) + self.assertEqual(g_loss.shape, ()) + self.assertEqual(d_loss.shape, ()) + self.assertTrue(isinstance(g_grads, list)) + self.assertTrue(isinstance(d_grads, list)) + g_vars = self.model.generator.trainable_variables + d_vars = self.model.discriminator.trainable_variables + + self.assertEqual(len(g_grads), len(g_vars)) + self.assertEqual(len(d_grads), len(d_vars)) + + def test_compute_loss_and_grads_defun(self): + """Test `compute_loss_and_grads` function with defun.""" + compute_loss_and_grads = tfe.defun(self.model.compute_loss_and_grads) + g_loss, d_loss, g_grads, d_grads = compute_loss_and_grads( + self.real_images, self.noise, training=False) + self.assertEqual(g_loss.shape, ()) + self.assertEqual(d_loss.shape, ()) + self.assertTrue(isinstance(g_grads, list)) + self.assertTrue(isinstance(d_grads, list)) + g_vars = self.model.generator.trainable_variables + d_vars = self.model.discriminator.trainable_variables + + self.assertEqual(len(g_grads), len(g_vars)) + self.assertEqual(len(d_grads), len(d_vars)) + + +if __name__ == "__main__": + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index fee9db46fa4f79d7dd613436726e8ddad51faf1c..113aa7967c176b7f4c3cc6f1b12d150fd6149a3a 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -68,6 +68,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@async_clear_error @@run_test_in_graph_and_eager_modes +@@run_all_tests_in_graph_and_eager_modes @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 1937ffb583bc727df76470d072b35fb3c9acaa88..30d297a5fb2dd2f844093d790d051a79105984dd 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -117,7 +117,7 @@ py_library( py_test( name = "dnn_test", - size = "small", + size = "medium", srcs = ["python/estimator/dnn_test.py"], srcs_version = "PY2AND3", tags = [ diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index 7ff25b95c079c7e06d29e874bcaa0d2c13e7167e..f1c60a912c8b1daa7db34f46e92bcc36ab300716 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -53,6 +53,13 @@ class DNNEstimator(estimator.Estimator): l1_regularization_strength=0.001 )) + # Or estimator with warm-starting from a previous checkpoint. + estimator = DNNEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb], + hidden_units=[1024, 512, 256], + warm_start_from="/path/to/checkpoint/dir") + # Input builders def input_fn_train: # returns x, y pass @@ -92,7 +99,8 @@ class DNNEstimator(estimator.Estimator): activation_fn=nn.relu, dropout=None, input_layer_partitioner=None, - config=None): + config=None, + warm_start_from=None): """Initializes a `DNNEstimator` instance. Args: @@ -116,6 +124,11 @@ class DNNEstimator(estimator.Estimator): input_layer_partitioner: Optional. Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. + warm_start_from: A string filepath to a checkpoint to warm-start from, or + a `WarmStartSettings` object to fully configure warm-starting. If the + string filepath is provided instead of a `WarmStartSettings`, then all + weights are warm-started, and it is assumed that vocabularies and Tensor + names are unchanged. """ def _model_fn(features, labels, mode, config): return dnn_lib._dnn_model_fn( # pylint: disable=protected-access @@ -131,4 +144,5 @@ class DNNEstimator(estimator.Estimator): input_layer_partitioner=input_layer_partitioner, config=config) super(DNNEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, model_dir=model_dir, config=config, + warm_start_from=warm_start_from) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py index 75e3107670d658e55ce23d983e47311f1c180104..050b0428bf7b685229e12561cfb0682d931299d2 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py @@ -38,7 +38,7 @@ from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache -def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): +def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg """Returns a DNNEstimator that uses regression_head.""" return dnn.DNNEstimator( head=head_lib.regression_head( @@ -48,6 +48,12 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): *args, **kwargs) +def _dnn_estimator_classifier_fn(n_classes=3, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg + """Returns a DNNEstimator that uses multi_class_head.""" + return dnn.DNNEstimator(head=head_lib.multi_class_head(n_classes=n_classes), + *args, **kwargs) + + class DNNEstimatorEvaluateTest( dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): @@ -75,6 +81,15 @@ class DNNEstimatorTrainTest( self, _dnn_estimator_fn) +class DNNEstimatorWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNWarmStartingTest.__init__( + self, _dnn_estimator_classifier_fn, _dnn_estimator_fn) + + class DNNEstimatorIntegrationTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index b798769d2cfde69e9e0b8d65882a07d038cbb994..9594e5132fd20dadea118fd1dd6768feb7fd7fff 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -529,6 +529,7 @@ def multi_label_head(n_classes, applications, the shape is `[batch_size, n_classes]`. Labels can be: + * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` * An integer `SparseTensor` of class indices. The `dense_shape` must be `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index b588f75efe9d0bbf8213a89978a627c0a0ccf554..05bcdac2caa77062f9a8a44a948d2897b439ea1f 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -95,7 +95,7 @@ def sequence_input_layer( Raises: ValueError: If any of the `feature_columns` is the wrong type. """ - feature_columns = fc._clean_feature_columns(feature_columns) + feature_columns = fc._normalize_feature_columns(feature_columns) for c in feature_columns: if not isinstance(c, fc._SequenceDenseColumn): raise ValueError( diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 89b5f4c4137f6c42417f539a578fd8b11f8b235d..45d7b740462ca21139e2e93e34b43668f1e08a94 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -110,7 +110,7 @@ class SequenceInputLayerTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) def test_embedding_column_with_non_sequence_categorical(self): - """Tests that error is raised for non-sequence categorical column.""" + """Tests that error is raised for non-sequence embedding column.""" vocabulary_size = 3 sparse_input = sparse_tensor.SparseTensorValue( # example 0, ids [2] @@ -132,6 +132,107 @@ class SequenceInputLayerTest(test.TestCase): features={'aaa': sparse_input}, feature_columns=[embedding_column_a]) + def test_shared_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + + def _get_initializer(embedding_dimension, embedding_values): + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 3., 4.], [0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 5., 6.], [3., 4., 1., 2.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + # Test that columns are reordered alphabetically. + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_b, categorical_column_a], + dimension=embedding_dimension, + initializer=_get_initializer(embedding_dimension, embedding_values)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + feature_columns=shared_embedding_columns) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_shared_embedding_column_with_non_sequence_categorical(self): + """Tests that error is raised for non-sequence shared embedding column.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + + categorical_column_a = fc.categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = fc.categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + with self.assertRaisesRegexp( + ValueError, + r'In embedding_column: aaa_shared_embedding\. categorical_column must ' + r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'): + _, _ = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b + }, + feature_columns=shared_embedding_columns) + def test_indicator_column(self): vocabulary_size_a = 3 sparse_input_a = sparse_tensor.SparseTensorValue( @@ -578,6 +679,182 @@ class SequenceEmbeddingColumnTest(test.TestCase): expected_sequence_length, sequence_length.eval(session=sess)) +class SequenceSharedEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [0, 2] + # example 2, ids [0] + # example 3, ids [] + indices=((0, 0), (1, 0), (1, 1), (2, 0)), + values=(1, 0, 2, 0), + dense_shape=(4, 2)) + + expected_lookups_a = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + expected_lookups_b = [ + # example 0, ids [1] + [[3., 5.], [0., 0.]], + # example 1, ids [0, 2] + [[1., 2.], [7., 11.]], + # example 2, ids [0] + [[1., 2.], [0., 0.]], + # example 3, ids [] + [[0., 0.], [0., 0.]], + ] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], + dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[0] + embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[0] + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual(('embedding_weights:0',), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual( + expected_lookups_a, embedding_lookup_a.eval(session=sess)) + self.assertAllEqual( + expected_lookups_b, embedding_lookup_b.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length_a = [1, 2] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [0, 2] + # example 1, ids [1] + indices=((0, 0), (0, 1), (1, 0)), + values=(0, 2, 1), + dense_shape=(2, 2)) + expected_sequence_length_b = [2, 1] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + sequence_length_a = sess.run(sequence_length_a) + self.assertAllEqual(expected_sequence_length_a, sequence_length_a) + self.assertEqual(np.int64, sequence_length_a.dtype) + sequence_length_b = sess.run(sequence_length_b) + self.assertAllEqual(expected_sequence_length_b, sequence_length_b) + self.assertEqual(np.int64, sequence_length_b.dtype) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length_a = [0, 1, 2, 0, 1, 0] + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [] + # example 2, ids [] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [0, 1] + indices=((0, 0), (4, 0), (5, 0), (5, 1)), + values=(2, 1, 0, 1), + dense_shape=(6, 2)) + expected_sequence_length_b = [1, 0, 0, 0, 1, 2] + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + + shared_embedding_columns = fc.shared_embedding_columns( + [categorical_column_a, categorical_column_b], dimension=2) + + sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor( + _LazyBuilder({ + 'aaa': sparse_input_a + }))[1] + sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor( + _LazyBuilder({ + 'bbb': sparse_input_b + }))[1] + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length_a, sequence_length_a.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length_b, sequence_length_b.eval(session=sess)) + + class SequenceIndicatorColumnTest(test.TestCase): def test_get_sequence_dense_tensor(self): diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 40ae01bfcce1dde580e6a5f6d9c8ec1aa1abb83f..e8e318001972934c7d2154bc14744823a3ba09f9 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -712,7 +712,8 @@ class VariableDeviceChooser(object): num_tasks=0, job_name='ps', device_type='CPU', - device_index=0): + device_index=0, + replica=None): """Initialize VariableDeviceChooser. Usage: @@ -733,12 +734,15 @@ class VariableDeviceChooser(object): self._job_name = job_name self._device_type = device_type self._device_index = device_index + self._replica = replica self._num_tasks = num_tasks self._next_task_id = 0 def __call__(self, op): - device_spec = tf_device.DeviceSpec(device_type=self._device_type, - device_index=self._device_index) + device_spec = tf_device.DeviceSpec( + replica=self._replica, + device_type=self._device_type, + device_index=self._device_index) if self._num_tasks > 0: task_id = self._next_task_id self._next_task_id = (self._next_task_id + 1) % self._num_tasks diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 37ea6eb12aba7d25656f19cbbc86475c1228d916..7e0c7dbec1d9266b53a169fe83b88d1e3af77d04 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -506,6 +506,35 @@ class VariablesTest(test.TestCase): self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0') self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + def testVariableWithVariableDeviceChooserWithReplica(self): + + with ops.Graph().as_default(): + device_fn = variables_lib2.VariableDeviceChooser(replica=3, num_tasks=2) + with arg_scope([variables_lib2.variable], device=device_fn): + a = variables_lib2.variable('a', []) + b = variables_lib2.variable('b', []) + c = variables_lib2.variable('c', [], device='cpu:12') + d = variables_lib2.variable('d', []) + with ops.device('cpu:99'): + e_init = constant_op.constant(12) + e = variables_lib2.variable('e', initializer=e_init) + # The values below highlight how the VariableDeviceChooser puts initial + # values on the same device as the variable job. + self.assertDeviceEqual(a.device, '/job:ps/replica:3/task:0/cpu:0') + self.assertEqual(a.initial_value.op.colocation_groups(), + a.op.colocation_groups()) + self.assertDeviceEqual(b.device, '/job:ps/replica:3/task:1/cpu:0') + self.assertEqual(b.initial_value.op.colocation_groups(), + b.op.colocation_groups()) + self.assertDeviceEqual(c.device, '/cpu:12') + self.assertEqual(c.initial_value.op.colocation_groups(), + c.op.colocation_groups()) + self.assertDeviceEqual(d.device, '/job:ps/replica:3/task:0/cpu:0') + self.assertEqual(d.initial_value.op.colocation_groups(), + d.op.colocation_groups()) + self.assertDeviceEqual(e.device, '/job:ps/replica:3/task:1/cpu:0') + self.assertDeviceEqual(e.initial_value.device, '/cpu:99') + def testVariableGPUPlacement(self): with ops.Graph().as_default(): @@ -930,8 +959,8 @@ class AssignFromCheckpointTest(test.TestCase): return saver.save(sess, checkpoint_dir, global_step=global_step) def testLoadExistingVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'load_existing_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables')) init_value0 = 10.0 init_value1 = 20.0 @@ -944,8 +973,8 @@ class AssignFromCheckpointTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -960,8 +989,8 @@ class AssignFromCheckpointTest(test.TestCase): # Tests restoring PartitionedVariables and tests using a dictionary # of lists as the assign_from_checkpoint() var_list param. def testLoadPartitionedVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'load_partitioned_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_partitioned_variables')) init_value0 = np.array([[10.0, 11.0], [12.0, 13.0]]) init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case. @@ -974,15 +1003,14 @@ class AssignFromCheckpointTest(test.TestCase): partitioner = partitioned_variables.variable_axis_size_partitioner(2) var0 = variables_lib2.variable( 'var0', shape=init_value0.shape, partitioner=partitioner) - var0full = variables_lib2.variable( - 'var0full', shape=init_value0.shape) + var0full = variables_lib2.variable('var0full', shape=init_value0.shape) var1 = variables_lib2.variable( 'var1', shape=init_value1.shape, partitioner=partitioner) # Convert var0 and var1 into a list of underlying variables. vars_to_restore = {'var0': list(var0) + [var0full], 'var1': list(var1)} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -992,16 +1020,18 @@ class AssignFromCheckpointTest(test.TestCase): # Request and test the variable values. PartitionedVariables can't # be evaled so we wrap them in an identity. - self.assertTrue(np.array_equal( - init_value0, array_ops.identity(var0).eval())) - self.assertTrue(np.array_equal( - init_value0, var0full.eval())) - self.assertTrue(np.array_equal( - init_value1, array_ops.identity(var1).eval())) + self.assertTrue( + np.array_equal(init_value0, + array_ops.identity(var0).eval())) + self.assertTrue(np.array_equal(init_value0, var0full.eval())) + self.assertTrue( + np.array_equal(init_value1, + array_ops.identity(var1).eval())) def testRaisesValueErrorIfAVariableIsntFound(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'raises_value_error_if_var_isnt_found')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'raises_value_error_if_var_isnt_found')) init_value0 = 10.0 init_value1 = 20.0 @@ -1019,8 +1049,9 @@ class AssignFromCheckpointTest(test.TestCase): variables_lib2.assign_from_checkpoint(model_path, vars_to_restore) def testInitFromCheckpointWithScopes(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'init_from_checkpoint_with_scopes')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'init_from_checkpoint_with_scopes')) init_value0 = np.asarray( [1.0, 3.0, 9.0], dtype=np.float32).reshape((1, 3, 1)) @@ -1038,8 +1069,8 @@ class AssignFromCheckpointTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=init_value1.shape) vars_to_restore = {'layer0/v0': var0, 'layer1/v1': var1} - op, feed_dict = variables_lib2.assign_from_checkpoint(model_path, - vars_to_restore) + op, feed_dict = variables_lib2.assign_from_checkpoint( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1081,8 +1112,8 @@ class AssignFromCheckpointFnTest(test.TestCase): return saver.save(sess, checkpoint_dir, global_step=global_step) def testLoadExistingVariables(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'load_existing_variables')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1097,8 +1128,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1111,8 +1142,9 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), 'load_existing_vars_no_reshape')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), + 'load_existing_vars_no_reshape')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1127,8 +1159,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var1 = variables_lib2.variable('my_var1', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1138,9 +1170,10 @@ class AssignFromCheckpointFnTest(test.TestCase): init_fn(sess) def testLoadExistingVariablesDifferentShapeAllowReshape(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join( - self.get_temp_dir(), - 'load_existing_variables_different_shape_allow_reshape')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join( + self.get_temp_dir(), + 'load_existing_variables_different_shape_allow_reshape')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1169,8 +1202,8 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testNotFoundError(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'not_found_error')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'not_found_error')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1186,8 +1219,8 @@ class AssignFromCheckpointFnTest(test.TestCase): var2 = variables_lib2.variable('my_var2', shape=[]) vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2} - init_fn = variables_lib2.assign_from_checkpoint_fn(model_path, - vars_to_restore) + init_fn = variables_lib2.assign_from_checkpoint_fn( + model_path, vars_to_restore) # Initialize the variables. sess.run(variables_lib.global_variables_initializer()) @@ -1197,8 +1230,8 @@ class AssignFromCheckpointFnTest(test.TestCase): init_fn(sess) def testMissingVariablesList(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'missing_variables_list')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'missing_variables_list')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1228,8 +1261,8 @@ class AssignFromCheckpointFnTest(test.TestCase): self.assertEqual(init_value1, var1.eval()) def testMissingVariablesDict(self): - model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(), - 'missing_variables_dict')) + model_dir = tempfile.mkdtemp( + prefix=os.path.join(self.get_temp_dir(), 'missing_variables_dict')) if gfile.Exists(model_dir): gfile.DeleteRecursively(model_dir) @@ -1279,9 +1312,8 @@ class ZeroInitializerOpTest(test.TestCase): def testZeroInitializer(self): for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64): for use_init in (False, True): - self._testZeroInitializer( - [10, 20], array_ops.ones( - [10, 20], dtype=dtype), use_init) + self._testZeroInitializer([10, 20], array_ops.ones( + [10, 20], dtype=dtype), use_init) class ZeroVarInitializerOpTest(test.TestCase): diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index ff903a78cc36c1965b7655aa902501b1943637a8..5b5557bd8f12b4d42e508f185cb8561eaebea84e 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -24,6 +24,7 @@ from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.ops import metrics as metrics_lib @@ -182,7 +183,10 @@ class GANHead(head._Head): # pylint: disable=protected-access if mode == model_fn_lib.ModeKeys.PREDICT: return model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.PREDICT, - predictions=gan_model.generated_data) + predictions=gan_model.generated_data, + export_outputs={ + 'predict': export_output.PredictOutput(gan_model.generated_data) + }) elif mode == model_fn_lib.ModeKeys.EVAL: gan_loss = self.create_loss( features=None, mode=mode, logits=gan_model, labels=None) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 6587f1fc600b94d27f7c12b44ca2136d0be5a8c5..5309d87765694fa476dae006105e842420a7c437 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -26,8 +26,11 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import training +_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument return math_ops.reduce_sum(gan_model.discriminator_real_outputs - @@ -71,13 +74,15 @@ class GANHeadTest(test.TestCase): return {} def _test_modes_helper(self, mode): - self.gan_head.create_estimator_spec( + return self.gan_head.create_estimator_spec( features=None, mode=mode, logits=get_gan_model()) def test_modes_predict(self): - self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + spec = self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + self.assertItemsEqual((_DEFAULT_SERVING_KEY, 'predict'), + spec.export_outputs.keys()) def test_modes_eval(self): self._test_modes_helper(model_fn_lib.ModeKeys.EVAL) diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index 1f9dd0decb84cf9b7b703f18c061d3c0c7a1cb25..9025c992a4467f521d6d8d514e6a5e92f5492947 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -57,7 +57,7 @@ Status GdrServer::Init() { new GdrWorker(env, remote_memory_manager_.get())); }; TF_RETURN_IF_ERROR( - GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func)); + GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func)); return remote_memory_manager_->Init(); } diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py index b4a99867ed46897f60be3f230838c3f576d5455e..61f78febfc07bb4e677259366a81c16b2b585244 100644 --- a/tensorflow/contrib/integrate/python/ops/odes.py +++ b/tensorflow/contrib/integrate/python/ops/odes.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops @@ -279,13 +278,27 @@ def _assert_increasing(t): return ops.control_dependencies([assert_increasing]) -def _check_input_types(t, y0): +def _check_input_types(y0, t, dt=None): if not (y0.dtype.is_floating or y0.dtype.is_complex): raise TypeError('`y0` must have a floating point or complex floating ' 'point dtype') if not t.dtype.is_floating: raise TypeError('`t` must have a floating point dtype') + if dt is not None and not dt.dtype.is_floating: + raise TypeError('`dt` must have a floating point dtype') + + +def _check_input_sizes(t, dt): + if len(t.get_shape().as_list()) > 1: + raise ValueError('t must be a 1D tensor') + + if len(dt.get_shape().as_list()) > 1: + raise ValueError('t must be a 1D tensor') + + if t.get_shape()[0] != dt.get_shape()[0] + 1: + raise ValueError('t and dt have incompatible lengths, must be N and N-1') + def _dopri5(func, y0, @@ -510,7 +523,7 @@ def odeint(func, # avoiding the need to pack/unpack in user functions. y0 = ops.convert_to_tensor(y0, name='y0') t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') - _check_input_types(t, y0) + _check_input_types(y0, t) error_dtype = abs(y0).dtype rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol') @@ -530,24 +543,74 @@ def odeint(func, class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): """Base class for fixed-grid ODE integrators.""" - def integrate(self, evol_func, y0, time_grid): - time_delta_grid = time_grid[1:] - time_grid[:-1] - - scan_func = self._make_scan_func(evol_func) + def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals): + """Returns integrated values of differential equation on the `time grid`. + + Numerically integrates differential equation defined via time derivative + evaluator `evol_func` using fixed time steps specified in dt_grid. + + Args: + evol_func: Callable, evaluates time derivative of y at a given time. + y0: N-D Tensor holds initial values of the solution. + time_grid: 1-D Tensor holding the time points at which the solution + will be recorded, must have a floating dtype. + dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid + intervals. Must be a floating dtype and have one less element than that + of the time_grid. + steps_on_intervals: 1-D Tensor of integer dtype, must have the same size + as dt_grid. Specifies number of steps needed for every interval. Assumes + steps_on_intervals * dt_grid == time intervals. + + Returns: + (N+1)-D tensor, where the first dimension corresponds to different + time points. Contains the solved value of y for each desired time point in + `t`, with the initial value `y0` being the first element along the first + dimension. + """ - y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid), - y0) - return array_ops.concat([[y0], y_grid], axis=0) + iteration_func = self._make_iteration_func(evol_func, dt_grid) + integrate_interval = self._make_interval_integrator(iteration_func, + steps_on_intervals) - def _make_scan_func(self, evol_func): + num_times = array_ops.size(time_grid) + current_time = time_grid[0] + solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times) + solution_array = solution_array.write(0, y0) - def scan_func(y, t_and_dt): - t, dt = t_and_dt + solution_array, _, _, _ = control_flow_ops.while_loop( + lambda _, __, ___, i: i < num_times, + integrate_interval, + (solution_array, y0, current_time, 1) + ) + solution_array = solution_array.stack() + solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape())) + return solution_array + + def _make_iteration_func(self, evol_func, dt_grid): + """Returns a function that builds operations of a single time step.""" + + def iteration_func(y, t, dt_step, interval_step): + """Performs a single time step advance.""" + dt = dt_grid[interval_step - 1] dy = self._step_func(evol_func, t, dt, y) dy = math_ops.cast(dy, dtype=y.dtype) - return y + dy + return y + dy, t + dt, dt_step + 1, interval_step + + return iteration_func + + def _make_interval_integrator(self, iteration_func, interval_sizes): + """Returns a function that builds operations for interval integration.""" - return scan_func + def integrate_interval(solution_array, y, t, interval_num): + """Integrates y with fixed time step on interval `interval_num`.""" + y, t, _, _ = control_flow_ops.while_loop( + lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1], + iteration_func, + (y, t, 0, interval_num) + ) + return solution_array.write(interval_num, y), y, t, interval_num + 1 + + return integrate_interval @abc.abstractmethod def _step_func(self, evol_func, t, dt, y): @@ -555,6 +618,7 @@ class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): class _MidpointFixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing midpoint scheme.""" def _step_func(self, evol_func, t, dt, y): dt_cast = math_ops.cast(dt, y.dtype) @@ -563,6 +627,7 @@ class _MidpointFixedGridIntegrator(_FixedGridIntegrator): class _RK4FixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing RK4 scheme.""" def _step_func(self, evol_func, t, dt, y): k1 = evol_func(y, t) @@ -575,7 +640,7 @@ class _RK4FixedGridIntegrator(_FixedGridIntegrator): return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6) -def odeint_fixed(func, y0, t, method='rk4', name=None): +def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None): """ODE integration on a fixed grid (with no step size control). Useful in certain scenarios to avoid the overhead of adaptive step size @@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): `y`. The initial time point should be the first element of this sequence, and each time must be larger than the previous time. May have any floating point dtype. + dt: 0-D or 1-D Tensor providing time step suggestion to be used on time + integration intervals in `t`. 1-D Tensor should provide values + for all intervals, must have 1 less element than that of `t`. + If given a 0-D Tensor, the value is interpreted as time step suggestion + same for all intervals. If passed None, then time step is set to be the + t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by + insuring an integer number of steps per interval, potentially reducing the + time step. method: One of 'midpoint' or 'rk4'. name: Optional name for the resulting operation. @@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): Raises: ValueError: Upon caller errors. """ - with ops.name_scope(name, 'odeint_fixed', [y0, t]): + with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]): t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') y0 = ops.convert_to_tensor(y0, name='y0') - _check_input_types(t, y0) + + intervals = t[1:] - t[:-1] + if dt is None: + dt = intervals + dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt') + + steps_on_intervals = math_ops.ceil(intervals / dt) + dt = intervals / steps_on_intervals + steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32) + + _check_input_types(y0, t, dt) + _check_input_sizes(t, dt) with _assert_increasing(t): with ops.name_scope(method): if method == 'midpoint': - return _MidpointFixedGridIntegrator().integrate(func, y0, t) + return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) elif method == 'rk4': - return _RK4FixedGridIntegrator().integrate(func, y0, t) + return _RK4FixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) else: raise ValueError('method not supported: {!s}'.format(method)) diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py index 3ec01212d25ca8dc6e13f340177a5e85138868d5..c7b4e2faa84e1a87cb1904b22eb0008ab1ee4be6 100644 --- a/tensorflow/contrib/integrate/python/ops/odes_test.py +++ b/tensorflow/contrib/integrate/python/ops/odes_test.py @@ -242,40 +242,56 @@ class InterpolationTest(test.TestCase): class OdeIntFixedTest(test.TestCase): - def _test_integrate_sine(self, method): + def _test_integrate_sine(self, method, t, dt=None): def evol_func(y, t): del t return array_ops.stack([y[1], -y[0]]) y0 = [0., 1.] - time_grid = np.linspace(0., 10., 200) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2) - def _test_integrate_gaussian(self, method): + def _test_integrate_gaussian(self, method, t, dt=None): def evol_func(y, t): return -math_ops.cast(t, dtype=y.dtype) * y[0] y0 = [1.] - time_grid = np.linspace(0., 2., 100) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2) + + def _test_integrate_sine_all(self, method): + uniform_time_grid = np.linspace(0., 10., 200) + non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0]) + uniform_dt = 0.02 + non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03]) + self._test_integrate_sine(method, uniform_time_grid) + self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt) + + def _test_integrate_gaussian_all(self, method): + uniform_time_grid = np.linspace(0., 2., 100) + non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0]) + uniform_dt = 0.01 + non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03]) + self._test_integrate_gaussian(method, uniform_time_grid) + self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt) def _test_everything(self, method): - self._test_integrate_sine(method) - self._test_integrate_gaussian(method) + self._test_integrate_sine_all(method) + self._test_integrate_gaussian_all(method) def test_midpoint(self): self._test_everything('midpoint') @@ -283,6 +299,21 @@ class OdeIntFixedTest(test.TestCase): def test_rk4(self): self._test_everything('rk4') + def test_dt_size_exceptions(self): + times = np.linspace(0., 2., 100) + dt = np.ones(99) * 0.01 + dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03]) + dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0) + times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0) + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_length) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_dim) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times_wrong_dim, dt) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index 3ba1026383ef146adb32197ae41b5c251155bf46..2ede5daee74223e812cc29e9708b1989b698fb4e 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -652,7 +652,8 @@ def map_fn(fn, labeled_tensor, name=None): tensor_lt = core.LabeledTensor(tensor, original_axes) return fn(tensor_lt).tensor - map_op = functional_ops.map_fn(tf_fn, labeled_tensor.tensor) + map_op = functional_ops.map_fn( + tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype) map_lt = core.LabeledTensor(map_op, final_axes) return core.identity(map_lt, name=scope) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 00f03a111ae8be7f49761ef5fb5a82810bcca182..bc3359693562deb1229a78a2db5c256c76f7fd8d 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -19,6 +19,8 @@ See the @{$python/contrib.layers} guide. @@avg_pool2d @@avg_pool3d @@batch_norm +@@convolution +@@convolution1d @@convolution2d @@convolution3d @@conv2d_in_plane diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index 06060b99e7e58787994f20f037ffa451abbc7459..a85cff4f7098e9a5eedca1b0c8c0cb42e172d90a 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -683,11 +683,12 @@ def parse_feature_columns_from_sequence_examples( the serialized proto. Returns: - A tuple consisting of: - context_features: a dict mapping `FeatureColumns` from - `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s. - sequence_features: a dict mapping `FeatureColumns` from - `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s. + A tuple consisting of (context_features, sequence_features) + + * context_features: a dict mapping `FeatureColumns` from + `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s. + * sequence_features: a dict mapping `FeatureColumns` from + `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s. """ # Sequence example parsing requires a single (scalar) example. try: diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index b7194ae33304509a51c2a079bcf89a108f40492b..b6d63c9640611abdda65f1205f544ee505dae1f0 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -57,10 +57,10 @@ from tensorflow.python.training import moving_averages __all__ = [ 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', - 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', - 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', - 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', - 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', + 'convolution1d', 'convolution2d', 'convolution2d_in_plane', + 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose', + 'dense_to_sparse', 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', + 'gdn', 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'sequence_to_images', 'softmax', 'spatial_softmax', 'stack', 'unit_norm', diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 56e9194cebbe46907707f7ac0996f9a56fb53c0f..c5c7269b1f15849956e90654e3bcf8ab0eebc393 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1312,6 +1312,29 @@ class ConvolutionInPlaneTest(test.TestCase): self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) + def testConv1dShape(self): + width = 7 + with self.test_session(): + images = random_ops.random_uniform((5, width, 3), seed=1) + output = layers_lib.convolution1d(images, 32, 3) + self.assertEqual(output.op.name, 'Conv/Relu') + self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) + + def testConvInferSpatialDims(self): + depth, height, width = 7, 9, 11 + with self.test_session(): + images = np.random.uniform(size=(5, width, 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3]) + self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) + images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3, 3]) + self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32]) + images = np.random.uniform(size=(5, depth, height, width, + 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3, 3, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, depth, height, width, 32]) + class DenseToSparseTest(test.TestCase): diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 541da9061732ad271f6d5456446a9c30b81e58dd..f8a3709ee57a32734afa7ac8133271c75d152b2c 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -505,7 +505,7 @@ class Experiment(object): eval_result = None last_warning_time = 0 while (not predicate_fn or predicate_fn( - eval_result, checkpoint_path=previous_path if eval_result else None)): + eval_result, checkpoint_path=previous_path)): # Exit if we have already reached number of steps to train. if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index d10927a0cdd5c67c8d2a8e569153235ee175ec4d..fb16c94c29660e2777942ea9cf30da51dbf90571 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -500,7 +500,7 @@ class ExperimentTest(test.TestCase): noop_hook = _NoopHook() def _predicate_fn(eval_result, checkpoint_path): - self.assertEqual(not eval_result, + self.assertEqual(eval_result is None, checkpoint_path is None) return est.eval_count < 3 # pylint: disable=cell-var-from-loop diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 9c804d27854b8004d34c65691b48ca2b0d3bbf7c..8c17c65fcc0dbd58e2b3e9042a983e400cd6c2b9 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -184,6 +184,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/contrib/lite/kernels:kernel_util", "//tensorflow/contrib/lite/kernels/internal:tensor_utils", "//tensorflow/contrib/lite/schema:schema_fbs", diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index cc8a8035d1dadeec98886ba1dae4cdf403f26de4..2b6997146e1e5a3873ed0f94a9221b34bed7621d 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -70,6 +70,12 @@ LIB_PATH := $(LIBDIR)$(LIB_NAME) # A small example program that shows how to link against the library. MINIMAL_PATH := $(BINDIR)minimal +# Benchmark static library and binary +BENCHMARK_LIB_NAME := benchmark-lib.a +BENCHMARK_BINARY_NAME := benchmark_model +BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME) +BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME) + MINIMAL_SRCS := \ tensorflow/contrib/lite/examples/minimal/minimal.cc MINIMAL_OBJS := $(addprefix $(OBJDIR), \ @@ -78,12 +84,19 @@ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) # What sources we want to compile, must be kept in sync with the main Bazel # build files. +PROFILER_SRCS := \ + tensorflow/contrib/lite/profiling/time.cc +PROFILE_SUMMARIZER_SRCS := \ + tensorflow/contrib/lite/profiling/profile_summarizer.cc \ + tensorflow/core/util/stats_calculator.cc + CORE_CC_ALL_SRCS := \ $(wildcard tensorflow/contrib/lite/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \ +$(PROFILER_SRCS) \ $(wildcard tensorflow/contrib/lite/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ @@ -107,18 +120,31 @@ TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) LIB_OBJS := $(TF_LITE_CC_OBJS) + +# Benchmark sources +BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark +BENCHMARK_ALL_SRCS := $(TFLITE_CC_SRCS) \ + $(wildcard $(BENCHMARK_SRCS_DIR)/*.cc) \ + $(PROFILE_SUMMARIZER_SRCS) + +BENCHMARK_SRCS := $(filter-out \ + $(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \ + $(BENCHMARK_ALL_SRCS)) + +BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) + # For normal manually-created TensorFlow C++ source files. $(OBJDIR)%.o: %.cc @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ - # For normal manually-created TensorFlow C++ source files. $(OBJDIR)%.o: %.c @mkdir -p $(dir $@) $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ # The target that's compiled if there's no command-line arguments. -all: $(LIB_PATH) $(MINIMAL_PATH) +all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY) # Gathers together all the objects we've compiled into a single '.a' archive. $(LIB_PATH): $(LIB_OBJS) @@ -131,6 +157,21 @@ $(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH) -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) + +$(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS) + @mkdir -p $(dir $@) + $(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS) + +benchmark_lib: $(BENCHMARK_LIB) +$(info $(BENCHMARK_BINARY)) +$(BENCHMARK_BINARY) : $(BENCHMARK_LIB) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(BENCHMARK_BINARY) \ + $(LIBFLAGS) $(BENCHMARK_LIB) $(LDFLAGS) $(LIBS) + +benchmark: $(BENCHMARK_BINARY) + # Gets rid of all generated files. clean: rm -rf $(MAKEFILE_DIR)/gen diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 4f836d367747e06de682b5764206d33f6e2fb983..22be64d6ff649b4bff45a5e5680984d688a8cf38 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -31,7 +31,7 @@ struct AllocationInfo { // The tensor index to be allocated or deallocated. int tensor; // Whether to allocate or deallocate - enum { ALLOC, DEALLOC } type; + enum Type { ALLOC, DEALLOC } type; }; ArenaPlanner::ArenaPlanner(TfLiteContext* context, @@ -67,6 +67,33 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Keeps track of references to each tensor. std::vector refcounts(graph_info_->num_tensors(), 0); + // `allocated` and `deallocated` are technically list of boolean values. + // We're saving the compiled binary size by using `vector`. + std::vector allocated(graph_info_->num_tensors(), false); + std::vector deallocated(graph_info_->num_tensors(), false); + + auto allocate = [this, &allocated, &deallocated](int node, + int tensor) -> TfLiteStatus { + if (allocated[tensor]) { + return kTfLiteOk; + } + TF_LITE_ENSURE(context_, !deallocated[tensor]); + alloc_queue_.push_back({node, tensor, AllocationInfo::ALLOC}); + allocated[tensor] = true; + return kTfLiteOk; + }; + + auto deallocate = [this, &allocated, &deallocated]( + int node, int tensor) -> TfLiteStatus { + if (!allocated[tensor]) { + // Do not enqueue a DEALLOC if the tensor is never allocated. + // This happened with the constant tensors. + return kTfLiteOk; + } + TF_LITE_ENSURE(context_, !deallocated[tensor]); + alloc_queue_.push_back({node, tensor, AllocationInfo::DEALLOC}); + return kTfLiteOk; + }; // There will be an entry in alloc_queue_ for the allocation of each tensor // and another for their deallocation. @@ -79,6 +106,28 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { refcounts[tensor_index]++; } + // Variable tensors should are also never overwritten and need to be alive all + // the time. + for (int tensor_index : graph_info_->variables()) { + refcounts[tensor_index]++; + } + + // Queue all graph inputs for allocation. + for (int tensor_index : graph_info_->inputs()) { + if (tensor_index != kOptionalTensor) { + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); + } + } + + // Queue all graph variable tensors for allocation. + for (int tensor_index : graph_info_->variables()) { + if (tensor_index != kOptionalTensor) { + // Increase the reference count for input tensors by one, so it will + // never be deallocated. + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); + } + } + // Count references to node input tensors. for (int i = 0; i < graph_info_->num_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); @@ -94,10 +143,9 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { // Queue all graph inputs for allocation. for (int tensor_index : graph_info_->inputs()) { if (tensor_index != kOptionalTensor) { - alloc_queue_.push_back({0, tensor_index, AllocationInfo::ALLOC}); + TF_LITE_ENSURE_STATUS(allocate(0, tensor_index)); } } - // Go through the graph in execution order. for (int i = 0; i < graph_info_->num_nodes(); ++i) { const TfLiteNode& node = graph_info_->node(i); @@ -106,7 +154,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { TfLiteIntArray* node_outputs = node.outputs; for (int j = 0; j < node_outputs->size; ++j) { int tensor_index = node_outputs->data[j]; - alloc_queue_.push_back({i, tensor_index, AllocationInfo::ALLOC}); + TF_LITE_ENSURE_STATUS(allocate(i, tensor_index)); } // Then update the ref-counts of the node's inputs, and if necessary queue @@ -117,7 +165,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { if (tensor_index != kOptionalTensor) { refcounts[tensor_index]--; if (refcounts[tensor_index] == 0) { - alloc_queue_.push_back({i, tensor_index, AllocationInfo::DEALLOC}); + TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index)); } } } diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index a8a8755e2c9e81474f2ff9cd2b85c0eb3d5c3441..f0fd35216f645df59b03340e00daca9322721b1b 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -100,12 +100,18 @@ class TestGraph { std::vector* tensors() { return &tensors_; } const std::vector& inputs() { return inputs_; } const std::vector& outputs() { return outputs_; } + const std::vector& variables() { return variables_; } + + void SetVariables(const std::vector& variables) { + variables_ = variables; + } private: std::vector nodes_; std::vector tensors_; std::vector inputs_; std::vector outputs_; + std::vector variables_; }; // The GraphInfo for a TestGraph. @@ -123,6 +129,9 @@ class TestGraphInfo : public GraphInfo { } const std::vector& inputs() const override { return graph_->inputs(); } const std::vector& outputs() const override { return graph_->outputs(); } + const std::vector& variables() const override { + return graph_->variables(); + } private: TestGraph* graph_; @@ -209,11 +218,8 @@ TEST_F(ArenaPlannerTest, ZeroSizedTensors) { TestGraph graph({1}, {{{1}, {2}, {}}}, {2}); (*graph.tensors())[1].bytes = 0; SetGraph(&graph); - // TODO(ahentz): this is currently broken because the arena finds two - // allocations with the same offset and returns an error. - ASSERT_FALSE(planner_->ExecuteAllocations(0, 10) == kTfLiteOk); - // EXPECT_EQ(GetOffset(1), 0); - // EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + ASSERT_EQ(planner_->ExecuteAllocations(0, 10), kTfLiteOk); + EXPECT_EQ((*graph_->tensors())[1].data.raw, nullptr); } TEST_F(ArenaPlannerTest, SimpleGraph) { @@ -309,13 +315,15 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentTensor) { { /* in, out, tmp */ {{0, 1}, {2}, {}}, // First op - {{2, 0}, {4}, {5}}, // Second op, with temporary + {{2, 0}, {4}, {5}}, // Second op, with persistent {{4, -1}, {3}, {}} // Third op, with optional }, {3}); // Make #1 persistent so it goes into its own arena. (*graph.tensors())[1].allocation_type = kTfLiteArenaRwPersistent; + // The only use case for kTfLiteArenaRwPersistent is variable tensor now. + graph.SetVariables({1}); SetGraph(&graph); Execute(0, 10); diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index aa6a60dc9ed308c14b360f0f9b9f6ee2c98f0669..81883ba1fd5a2b0bde62b49d67d50dc5a3e281a0 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -204,6 +204,7 @@ def generated_test_models(): "conv", "depthwiseconv", "div", + "equal", "exp", "expand_dims", "floor", @@ -213,12 +214,14 @@ def generated_test_models(): "global_batch_norm", "greater", "greater_equal", + "sum", "l2norm", "l2_pool", "less", "less_equal", "local_response_norm", "log_softmax", + "log", "lstm", "max_pool", "maximum", @@ -226,6 +229,7 @@ def generated_test_models(): "minimum", "mul", "neg", + "not_equal", "pad", "padv2", # "prelu", @@ -234,6 +238,8 @@ def generated_test_models(): "relu6", "reshape", "resize_bilinear", + "rsqrt", + "shape", "sigmoid", "sin", "slice", @@ -242,6 +248,7 @@ def generated_test_models(): "space_to_depth", "sparse_to_dense", "split", + "sqrt", "squeeze", "strided_slice", "strided_slice_1d_exhaustive", diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index 9f398f4a9f3dcafd7bd49fd5d95e9991b8b36b75..e9531aef19f04adf719156aa3e874dc5ce6e2b04 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -19,22 +19,23 @@ set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR/../../.." -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_x86_64/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_i386/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_armv7/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 \ -$SCRIPT_DIR/gen/lib/ios_armv7s/libtensorflow-lite.a -make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 \ -$SCRIPT_DIR/gen/lib/ios_arm64/libtensorflow-lite.a +# Build library for supported architectures and packs them in a fat binary. +make_library() { + for arch in x86_64 i386 armv7 armv7s arm64 + do + make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=${arch} \ + -j 8 \ + $SCRIPT_DIR/gen/lib/ios_${arch}/${1} + done + lipo \ + tensorflow/contrib/lite/gen/lib/ios_x86_64/${1} \ + tensorflow/contrib/lite/gen/lib/ios_i386/${1} \ + tensorflow/contrib/lite/gen/lib/ios_armv7/${1} \ + tensorflow/contrib/lite/gen/lib/ios_armv7s/${1} \ + tensorflow/contrib/lite/gen/lib/ios_arm64/${1} \ + -create \ + -output tensorflow/contrib/lite/gen/lib/${1} +} -lipo \ -tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_i386/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_armv7/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_armv7s/libtensorflow-lite.a \ -tensorflow/contrib/lite/gen/lib/ios_arm64/libtensorflow-lite.a \ --create \ --output tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a +make_library libtensorflow-lite.a +make_library benchmark-lib.a diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index c1cc4476fbd45fa6b3f5b3a1ed2cba39cc2ad54b..1b1b8b2985afda669c950eb1284d99d903e95455 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -215,7 +215,7 @@ typedef struct { typedef struct { bool keep_dims; -} TfLiteMeanParams; +} TfLiteReducerParams; typedef struct { int num_splits; @@ -250,6 +250,10 @@ typedef struct { bool validate_indices; } TfLiteSparseToDenseParams; +typedef struct { + TfLiteType out_type; +} TfLiteShapeParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index fc6fdd6eefb4ce9777b4a10ae6bbd8f2ec1ceeaa..7a78206ebf5f7a5e88e56723e874b9d552df05bd 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. #ifdef __cplusplus extern "C" { @@ -96,6 +96,13 @@ typedef enum { kTfLiteBuiltinSparseToDense = 68, kTfLiteBuiltinTile = 69, kTfLiteBuiltinExpandDims = 70, + kTfLiteBuiltinEqual = 71, + kTfLiteBuiltinNotEqual = 72, + kTfLiteBuiltinLog = 73, + kTfLiteBuiltinSum = 74, + kTfLiteBuiltinSqrt = 75, + kTfLiteBuiltinRsqrt = 76, + kTfLiteBuiltinShape = 77, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c index 5c6f5e72a47180cd98be46f60cfa8eaf28197806..7f2aa316f4a9a265b14a216a6ffa53c7f0757426 100644 --- a/tensorflow/contrib/lite/context.c +++ b/tensorflow/contrib/lite/context.c @@ -76,7 +76,7 @@ void TfLiteTensorFree(TfLiteTensor* t) { void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, TfLiteQuantizationParams quantization, char* buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, TfLiteTensor* tensor) { + const void* allocation, bool is_variable, TfLiteTensor* tensor) { TfLiteTensorFree(tensor); tensor->type = type; tensor->name = name; @@ -86,6 +86,7 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, tensor->bytes = size; tensor->allocation_type = allocation_type; tensor->allocation = allocation; + tensor->is_variable = is_variable; } void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 4eb66cc225eb04923be9aaa445a335ad822c8a6f..15a37de9dc665ff147b7094a61a5afab701932ce 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -138,6 +138,7 @@ typedef enum { kTfLiteInt64 = 4, kTfLiteString = 5, kTfLiteBool = 6, + kTfLiteInt16 = 7, } TfLiteType; // Parameters for asymmetric quantization. Quantized values can be converted @@ -148,7 +149,7 @@ typedef struct { int32_t zero_point; } TfLiteQuantizationParams; -// A union of points that points to memory for a given tensor. +// A union of pointers that points to memory for a given tensor. typedef union { int* i32; int64_t* i64; @@ -157,6 +158,7 @@ typedef union { const char* raw_const; uint8_t* uint8; bool* b; + int16_t* i16; } TfLitePtrUnion; // Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped @@ -223,6 +225,9 @@ typedef struct { // delegate buffer. // WARNING: This is an // experimental interface that is subject to change. bool data_is_stale; + + // True if the tensor is a variable. + bool is_variable; } TfLiteTensor; // Free data memory of tensor `t`; @@ -235,7 +240,8 @@ void TfLiteTensorFree(TfLiteTensor* t); void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, TfLiteQuantizationParams quantization, char* buffer, size_t size, TfLiteAllocationType allocation_type, - const void* allocation, TfLiteTensor* tensor); + const void* allocation, bool is_variable, + TfLiteTensor* tensor); // Resize the allocated data of a (dynamic) tensor. void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc index 0731d14419d2dec2ea5efa48ef5d4b7728af635f..e96ee92376901a341a1f739d0d79727deeb443eb 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -26,6 +26,10 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" +#ifdef __ANDROID__ +#include +#endif + namespace tflite { namespace { @@ -37,6 +41,29 @@ namespace { return kTfLiteError; \ } +namespace { +int32_t GetAndroidSdkVersion() { +#ifdef __ANDROID__ + const char* sdkProp = "ro.build.version.sdk"; + char sdkVersion[PROP_VALUE_MAX]; + int length = __system_property_get(sdkProp, sdkVersion); + if (length != 0) { + for (int i = 0; i < length; ++i) { + int digit = sdkVersion[i] - '0'; + if (digit < 0 || digit > 9) { + // Non-numeric SDK version, assume it's higher then expected; + return std::numeric_limits::max(); + } + } + return atoi(sdkVersion); + } +#endif // __ANDROID__ + return 0; +} + +static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion(); +} // namespace + // RAII NN API Model Destructor for use with std::unique_ptr struct NNFreeModel { void operator()(ANeuralNetworksModel* model) { @@ -71,7 +98,7 @@ class OperandMapping { // Add a new mapping from `tflite_index` and return the NN API tensor index. int add_new_ann_tensor_index(int tflite_index) { if (tflite_index >= lite_tensor_to_ann_tensor_.size()) { - lite_tensor_to_ann_tensor_.resize(tflite_index + 1); + lite_tensor_to_ann_tensor_.resize(tflite_index + 1, -1); } int new_tensor_index = next_ann_tensor_index_++; lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index; @@ -98,14 +125,22 @@ class NNAPIOpBuilder { operand_mapping_(tensor_mapping), nn_model_(nn_model) {} - TfLiteStatus AddScalarInt32Operand(int value) { - ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; - CHECK_NN(context_, - ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); - int ann_operand = operand_mapping_->add_new_non_tensor_operand(); - CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( - nn_model_, ann_operand, &value, sizeof(int32_t))); - augmented_inputs_.push_back(ann_operand); + TfLiteStatus AddScalarInt32Operand(int32_t value) { + return AddScalarOperand(value, ANEURALNETWORKS_INT32); + } + + TfLiteStatus AddScalarFloat32Operand(float value) { + return AddScalarOperand(value, ANEURALNETWORKS_FLOAT32); + } + + TfLiteStatus AddPoolingParams(void* data) { + auto builtin = reinterpret_cast(data); + AddScalarInt32Operand(builtin->padding); + AddScalarInt32Operand(builtin->stride_width); + AddScalarInt32Operand(builtin->stride_height); + AddScalarInt32Operand(builtin->filter_width); + AddScalarInt32Operand(builtin->filter_height); + AddScalarInt32Operand(builtin->activation); return kTfLiteOk; } @@ -149,7 +184,6 @@ class NNAPIOpBuilder { return kTfLiteOk; case kTfLiteFloat32: nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; - scale = 0.f; break; case kTfLiteUInt8: nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; @@ -158,8 +192,8 @@ class NNAPIOpBuilder { break; case kTfLiteInt32: nn_type = ANEURALNETWORKS_TENSOR_INT32; - scale = 0.f; - zeroPoint = 0; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; break; default: context_->ReportError(context_, "Logic error in NN API Delegate.\n"); @@ -192,12 +226,24 @@ class NNAPIOpBuilder { augmented_inputs_.data(), static_cast(augmented_outputs_.size()), augmented_outputs_.data())); - augmented_outputs_.clear(); + augmented_inputs_.clear(); augmented_outputs_.clear(); return kTfLiteOk; } private: + template + TfLiteStatus AddScalarOperand(T value, int32_t nn_type) { + ANeuralNetworksOperandType operand_type{.type = nn_type}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, &value, sizeof(T))); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + // TfLiteContext for error handling. Must be named context for macros to // work. TfLiteContext* context_; @@ -227,29 +273,143 @@ class NNAPIDelegateKernel { // Return a function that knows how to translate a node into its operands // when called. You can use this function to see if a node is supported // (i.e. that MappingFn is not nullptr). - MappingFn Map(TfLiteContext* context, int builtin_code, TfLiteNode* node) { + MappingFn Map(TfLiteContext* context, int builtin_code, int version, + TfLiteNode* node) { switch (builtin_code) { case kTfLiteBuiltinAdd: - return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { - auto builtin = reinterpret_cast(node->builtin_data); - builder->AddScalarInt32Operand(builtin->activation); - return ANEURALNETWORKS_ADD; - }; + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_ADD; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMul: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_MUL; + }; + } else { + return nullptr; + } break; case kTfLiteBuiltinAveragePool2d: - return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_AVERAGE_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMaxPool2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_MAX_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinL2Pool2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_L2_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinConv2d: + if (version == 1) { auto builtin = - reinterpret_cast(node->builtin_data); - builder->AddScalarInt32Operand(builtin->padding); - builder->AddScalarInt32Operand(builtin->stride_width); - builder->AddScalarInt32Operand(builtin->stride_height); - builder->AddScalarInt32Operand(builtin->filter_width); - builder->AddScalarInt32Operand(builtin->filter_height); - builder->AddScalarInt32Operand(builtin->activation); - return ANEURALNETWORKS_AVERAGE_POOL_2D; - }; + reinterpret_cast(node->builtin_data); + if (builtin->dilation_width_factor != 1 || + builtin->dilation_height_factor != 1 || node->inputs->size != 3) { + // NNAPI does not support dilated Conv2D. + return nullptr; + } + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_CONV_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDepthwiseConv2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->depth_multiplier); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_DEPTHWISE_CONV_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinFullyConnected: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_FULLY_CONNECTED; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSoftmax: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarFloat32Operand(builtin->beta); + return ANEURALNETWORKS_SOFTMAX; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinReshape: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RESHAPE; + }; + } else { + return nullptr; + } break; default: return nullptr; @@ -292,10 +452,14 @@ class NNAPIDelegateKernel { int relative_input_index = 0; for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { TfLiteTensor* tensor = &context->tensors[absolute_input_index]; - CHECK_NN(context, ANeuralNetworksExecution_setInput( - execution, relative_input_index, nullptr, - tensor->data.raw, tensor->bytes)); - relative_input_index++; + // TODO(miaowang): make sure the delegation works with dequantized weights + // as intermediate tensors. + if (tensor->allocation_type != kTfLiteMmapRo) { + CHECK_NN(context, ANeuralNetworksExecution_setInput( + execution, relative_input_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_input_index++; + } } // Set the output tensor buffers. @@ -345,8 +509,8 @@ class NNAPIDelegateKernel { TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); } // Get op type and operands - int nn_op_type = - Map(context, reg->builtin_code, node)(context, &builder, node); + int nn_op_type = Map(context, reg->builtin_code, reg->version, node)( + context, &builder, node); // Map outputs to NN API tensor indices. for (auto output_index : TfLiteIntArrayView(node->outputs)) { TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); @@ -368,8 +532,12 @@ class NNAPIDelegateKernel { std::vector outputs; outputs.reserve(output_tensors->size); // Make the TensorFlow lite inputs and outputs to ann_indices. - for (int i : TfLiteIntArrayView(input_tensors)) - inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + for (int i : TfLiteIntArrayView(input_tensors)) { + // Constant tensors are not NNAPI inputs. + if (context->tensors[i].allocation_type != kTfLiteMmapRo) { + inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + } + } for (int i : TfLiteIntArrayView(output_tensors)) outputs.push_back(operand_mapping_.lite_index_to_ann(i)); // Tell ANN to declare inputs/outputs @@ -392,7 +560,8 @@ TfLiteDelegate* NnApiDelegate() { .Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) -> TfLiteStatus { // Do not check nodes_ if NN API is unavailable. - if (!NNAPIExists()) return kTfLiteOk; + // NN API is only available since Android O-MR1 (API 27). + if (kAndroidSdkVersion < 27 || !NNAPIExists()) return kTfLiteOk; std::vector supported_nodes(1); // We don't care about all nodes_, we only care about ones in the @@ -400,6 +569,7 @@ TfLiteDelegate* NnApiDelegate() { TfLiteIntArray* plan; TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); int total_supported_nodes = 0; + // Check for every node if it is supported // TODO(b/80625235): Fix this to do more careful checking of versioning. for (int node_index : TfLiteIntArrayView(plan)) { @@ -408,7 +578,8 @@ TfLiteDelegate* NnApiDelegate() { TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( context, node_index, &node, ®istration)); NNAPIDelegateKernel dummy_kernel; - if (dummy_kernel.Map(context, registration->builtin_code, node)) { + if (dummy_kernel.Map(context, registration->builtin_code, + registration->version, node)) { supported_nodes.push_back(node_index); } total_supported_nodes += 1; diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc index ff2e721423f07889f36746a2889afcc3369f28fc..799e3efe0bb09b242d8e5b1d15d7a9646965a85d 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -21,8 +21,12 @@ limitations under the License. namespace tflite { namespace { +using ::testing::ElementsAre; using ::testing::ElementsAreArray; +// TODO(b/110368244): figure out how to share the existing tests in kernels/ but +// with the delegation on. Also, add more unit tests to improve code coverage. + class FloatAddOpModel : public SingleOpModel { public: FloatAddOpModel(const TensorData& input1, const TensorData& input2, @@ -72,6 +76,535 @@ TEST(NNAPIDelegate, AddWithRelu) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); } +class FloatMulOpModel : public SingleOpModel { + public: + FloatMulOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, MulWithNoActivation) { + FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); +} + +class FloatPoolingOpModel : public SingleOpModel { + public: + FloatPoolingOpModel(BuiltinOperator type, const TensorData& input, + int filter_width, int filter_height, + const TensorData& output) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + output_ = AddOutput(output); + + SetBuiltinOp( + type, BuiltinOptions_Pool2DOptions, + CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width, + filter_height, ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, AveragePoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75})); +} + +TEST(NNAPIDelegate, MaxPoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10})); +} + +TEST(NNAPIDelegate, L2PoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); +} + +class BaseConvolutionOpModel : public SingleOpModel { + public: + BaseConvolutionOpModel( + const TensorData& input, const TensorData& filter, + const TensorData& output, int stride_width = 2, int stride_height = 2, + enum Padding padding = Padding_VALID, + enum ActivationFunctionType activation = ActivationFunctionType_NONE, + int dilation_width_factor = 1, int dilation_height_factor = 1) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[0]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, + CreateConv2DOptions( + builder_, padding, stride_width, stride_height, activation, + dilation_width_factor, dilation_height_factor) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class ConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this tests we set the input and output scales so that the results +// match exactly the 'non-quantized' version. +TEST(NNAPIDelegate, SimpleTestQuantized) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 145, 129, 132, // + 145, 129, 132, // + 144, 131, 130, // + 164, 131, 130, // + })); +} + +TEST(NNAPIDelegate, Conv2DWithNoActivation) { + ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + })); +} + +class DepthwiseConvolutionOpModel : public SingleOpModel { + public: + DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter, + const TensorData& output) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[3]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + int input_depth = GetShape(input_)[3]; + int output_depth = GetShape(filter_)[3]; + int depth_mul = output_depth / input_depth; + + SetBuiltinOp( + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOptions_DepthwiseConv2DOptions, + CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul, + ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) { + DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 71, -34, 99, -20, // + 91, -26, 127, -4, // + })); +} + +class FloatFullyConnectedOpModel : public SingleOpModel { + public: + FloatFullyConnectedOpModel(int units, int batches, const TensorData& input, + const TensorData& output = {TensorType_FLOAT32}) + : batches_(batches), units_(units) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + int total_input_size = 1; + for (int i = 0; i < input.shape.size(); ++i) { + total_input_size *= input.shape[i]; + } + input_size_ = total_input_size / batches_; + + input_ = AddInput(input); + weights_ = + AddInput({input.type, {units_, input_size_}, input.min, input.max}); + + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {units_}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(weights_); + TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + SetBuiltinOp( + BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weights_; + int bias_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(NNAPIDelegate, FullyConnectedSimpleTest) { + FloatFullyConnectedOpModel m(/*units=*/3, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 10}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); +} + +class SoftmaxOpModel : public SingleOpModel { + public: + SoftmaxOpModel(int batches, int size, float beta) + : batches_(batches), input_size_(size), beta_(beta) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, beta_).Union()); + BuildInterpreter({{batches_, input_size_}}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; + + int batches_; + int input_size_; + float beta_; +}; + +TEST(NNAPIDelegate, SoftmaxSimpleTest) { + SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231}, + 1e-6))); +} + +class ReshapeOpModel : public SingleOpModel { + public: + ReshapeOpModel(std::initializer_list input_shape, + std::initializer_list new_shape) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(TensorType_FLOAT32); + new_shape_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) + .Union()); + BuildInterpreter({input_shape, {static_cast(new_shape.size())}}); + PopulateTensor(new_shape_, new_shape); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int new_shape_; + int output_; +}; + +TEST(NNAPIDelegate, ReshapeSimpleTest) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD index 57000072561303e8457f61b1ebe95d382fc01f10..dd2cd173246719976d7cd6e52d65f63125b5b2db 100644 --- a/tensorflow/contrib/lite/examples/android/BUILD +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow camera demo app for Android. +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 @@ -24,28 +26,28 @@ cc_library( android_binary( name = "tflite_demo", srcs = glob([ - "src/**/*.java", + "app/src/main/java/**/*.java", ]), # Package assets from assets dir as well as all model targets. # Remove undesired models (and corresponding Activities in source) # to reduce APK size. assets = [ - "//tensorflow/contrib/lite/examples/android/assets:labels_mobilenet_quant_v1_224.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", "@tflite_conv_actions_frozen//:conv_actions_frozen.tflite", - "//tensorflow/contrib/lite/examples/android/assets:conv_actions_labels.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:conv_actions_labels.txt", "@tflite_mobilenet_ssd//:mobilenet_ssd.tflite", - "//tensorflow/contrib/lite/examples/android/assets:box_priors.txt", - "//tensorflow/contrib/lite/examples/android/assets:coco_labels_list.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:box_priors.txt", + "//tensorflow/contrib/lite/examples/android/app/src/main/assets:coco_labels_list.txt", ], assets_dir = "", custom_package = "org.tensorflow.lite.demo", inline_constants = 1, - manifest = "AndroidManifest.xml", + manifest = "app/src/main/AndroidManifest.xml", nocompress_extensions = [ ".tflite", ], - resource_files = glob(["res/**"]), + resource_files = glob(["app/src/main/res/**"]), tags = [ "manual", "notap", @@ -55,31 +57,3 @@ android_binary( "//tensorflow/contrib/lite/java:tensorflowlite", ], ) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - "bin/**", - "gen/**", - "gradleBuild/**", - "libs/**", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -filegroup( - name = "java_files", - srcs = glob(["src/**/*.java"]), -) - -filegroup( - name = "resource_files", - srcs = glob(["res/**"]), -) - -exports_files(["AndroidManifest.xml"]) diff --git a/tensorflow/contrib/lite/examples/android/android.iml b/tensorflow/contrib/lite/examples/android/android.iml new file mode 100644 index 0000000000000000000000000000000000000000..f0a5ac2bf4cdfb7c98f5704310fbf2f16e9065a2 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/android.iml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/contrib/lite/examples/android/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..8e0a98ed63f99b7477cdb2f851a19cd31b45f314 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/build.gradle @@ -0,0 +1,60 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion '26.0.2' + defaultConfig { + applicationId "org.tensorflow.lite.demo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. + jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +// import DownloadModels task +project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets' +project.ext.TMP_DIR = project.buildDir.toString() + '/downloads' + +// Download default models; if you wish to use your own models then +// place them in the "assets" directory and comment out this line. +apply from: "download-models.gradle" + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/examples/android/app/download-models.gradle b/tensorflow/contrib/lite/examples/android/app/download-models.gradle new file mode 100644 index 0000000000000000000000000000000000000000..8e65dc076f2a8daaddf01ceab6796b8ed1127af3 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/app/download-models.gradle @@ -0,0 +1,73 @@ +/* + * download-models.gradle + * Downloads model files from ${MODEL_URL} into application's asset folder + * Input: + * project.ext.TMP_DIR: absolute path to hold downloaded zip files + * project.ext.ASSET_DIR: absolute path to save unzipped model files + * Output: + * 3 model files will be downloaded into given folder of ext.ASSET_DIR + */ +// hard coded model files +// LINT.IfChange + +def models = ['conv_actions_tflite.zip', + 'mobilenet_ssd_tflite_v1.zip', + 'mobilenet_v1_224_android_quant_2017_11_08.zip'] +// LINT.ThenChange(//tensorflow/examples/android/BUILD) + +// Root URL for model archives +def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite' + +buildscript { + repositories { + jcenter() + } + dependencies { + classpath 'de.undercouch:gradle-download-task:3.2.0' + } +} + +import de.undercouch.gradle.tasks.download.Download +task downloadFile(type: Download){ + for (f in models) { + def modelUrl = MODEL_URL + "/" + f + println "Downloading ${f} from ${modelUrl}" + src modelUrl + } + + dest new File(project.ext.TMP_DIR) + overwrite true +} + +task extractModels(type: Copy) { + for (f in models) { + def localFile = f.split("/")[-1] + from zipTree(project.ext.TMP_DIR + '/' + localFile) + } + + into file(project.ext.ASSET_DIR) + fileMode 0644 + exclude '**/LICENSE' + + def needDownload = false + for (f in models) { + def localFile = f.split("/")[-1] + if (!(new File(project.ext.TMP_DIR + '/' + localFile)).exists()) { + needDownload = true + } + } + + if (needDownload) { + dependsOn downloadFile + } +} + +tasks.whenTaskAdded { task -> + if (task.name == 'assembleDebug') { + task.dependsOn 'extractModels' + } + if (task.name == 'assembleRelease') { + task.dependsOn 'extractModels' + } +} + diff --git a/tensorflow/contrib/lite/examples/android/AndroidManifest.xml b/tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/AndroidManifest.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml diff --git a/tensorflow/contrib/lite/examples/android/assets/BUILD b/tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/BUILD rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD diff --git a/tensorflow/contrib/lite/examples/android/assets/box_priors.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/box_priors.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt diff --git a/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt similarity index 100% rename from tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt rename to tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java similarity index 100% rename from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java rename to tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java diff --git a/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png diff --git a/tensorflow/contrib/lite/examples/android/res/drawable/border.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/drawable/border.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml diff --git a/tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/attrs.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/attrs.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/base-strings.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/base-strings.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/colors.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/colors.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/strings.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/strings.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml diff --git a/tensorflow/contrib/lite/examples/android/res/values/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml similarity index 100% rename from tensorflow/contrib/lite/examples/android/res/values/template-styles.xml rename to tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/contrib/lite/examples/android/build.gradle index 0d4de358156a5d139e35cc542b8d36ab24e763b9..a47fa4bbf6730c7d1269737564381c8464224713 100644 --- a/tensorflow/contrib/lite/examples/android/build.gradle +++ b/tensorflow/contrib/lite/examples/android/build.gradle @@ -1,52 +1,23 @@ -apply plugin: 'com.android.application' +// Top-level build file where you can add configuration options common to all sub-projects/modules. -android { - compileSdkVersion 26 - buildToolsVersion "26.0.1" - defaultConfig { - applicationId "org.tensorflow.lite.demo" - minSdkVersion 15 - targetSdkVersion 26 - versionCode 1 - versionName "1.0" - testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" - - // Remove this block. - jackOptions { - enabled true - } - } - lintOptions { - abortOnError false - } - buildTypes { - release { - minifyEnabled false - proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' - } - } - aaptOptions { - noCompress "tflite" +buildscript { + repositories { + jcenter() } + dependencies { + classpath 'com.android.tools.build:gradle:3.0.1' - compileOptions { - sourceCompatibility JavaVersion.VERSION_1_8 - targetCompatibility JavaVersion.VERSION_1_8 + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files } } -repositories { - maven { - url 'https://google.bintray.com/tensorflow' +allprojects { + repositories { + jcenter() } } -dependencies { - compile fileTree(dir: 'libs', include: ['*.jar']) - androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { - exclude group: 'com.android.support', module: 'support-annotations' - }) - compile 'org.tensorflow:tensorflow-lite:+' - - testCompile 'junit:junit:4.12' +task clean(type: Delete) { + delete rootProject.buildDir } diff --git a/tensorflow/contrib/lite/examples/android/settings.gradle b/tensorflow/contrib/lite/examples/android/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e7b4def49cb53d9aa04228dd3edb14c9e635e003 --- /dev/null +++ b/tensorflow/contrib/lite/examples/android/settings.gradle @@ -0,0 +1 @@ +include ':app' diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD index 9322e186a280e932a2441ab16ac8579d9ab67ee2..c61445114ecc6dfbe4f2b6ab666b28a8aa746be3 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -53,19 +53,18 @@ cc_library( ], ) -# TODO(ahentz): Test disabled as it has a memory leek from read_bmp -# cc_test( -# name = "label_image_test", -# srcs = [ -# "get_top_n.h", -# "get_top_n_impl.h", -# "label_image_test.cc", -# ], -# data = [ -# "testdata/grace_hopper.bmp", -# ], -# deps = [ -# ":bitmap_helpers", -# "//testing/base/public:gunit", -# ], -# ) +cc_test( + name = "label_image_test", + srcs = [ + "get_top_n.h", + "get_top_n_impl.h", + "label_image_test.cc", + ], + data = [ + "testdata/grace_hopper.bmp", + ], + deps = [ + ":bitmap_helpers", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc index 0b38cd38c83927c65d251b9356301b6bef7521f2..2735d1f5ea4e2a104f71a3a6f874d9acb2f48142 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc @@ -28,8 +28,9 @@ limitations under the License. namespace tflite { namespace label_image { -uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, - int width, int height, int channels, bool top_down) { +std::vector decode_bmp(const uint8_t* input, int row_size, int width, + int height, int channels, bool top_down) { + std::vector output(height * width * channels); for (int i = 0; i < height; i++) { int src_pos; int dst_pos; @@ -66,12 +67,11 @@ uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, } } } - return output; } -uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, - int* channels, Settings* s) { +std::vector read_bmp(const std::string& input_bmp_name, int* width, + int* height, int* channels, Settings* s) { int begin, end; std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary); @@ -87,14 +87,15 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, if (s->verbose) LOG(INFO) << "len: " << len << "\n"; - const uint8_t* img_bytes = new uint8_t[len]; + std::vector img_bytes(len); file.seekg(0, std::ios::beg); - file.read((char*)img_bytes, len); + file.read(reinterpret_cast(img_bytes.data()), len); const int32_t header_size = - *(reinterpret_cast(img_bytes + 10)); - *width = *(reinterpret_cast(img_bytes + 18)); - *height = *(reinterpret_cast(img_bytes + 22)); - const int32_t bpp = *(reinterpret_cast(img_bytes + 28)); + *(reinterpret_cast(img_bytes.data() + 10)); + *width = *(reinterpret_cast(img_bytes.data() + 18)); + *height = *(reinterpret_cast(img_bytes.data() + 22)); + const int32_t bpp = + *(reinterpret_cast(img_bytes.data() + 28)); *channels = bpp / 8; if (s->verbose) @@ -110,10 +111,9 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, bool top_down = (*height < 0); // Decode image, allocating tensor once the image size is known - uint8_t* output = new uint8_t[abs(*height) * *width * *channels]; const uint8_t* bmp_pixels = &img_bytes[header_size]; - return decode_bmp(bmp_pixels, row_size, output, *width, abs(*height), - *channels, top_down); + return decode_bmp(bmp_pixels, row_size, *width, abs(*height), *channels, + top_down); } } // namespace label_image diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h index 97343dde6b31694e5b2de20b35a7083fb8fe4a0e..5fc75b1f7274c14d49e4a26d6ce4902c037afa6b 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h @@ -22,8 +22,8 @@ limitations under the License. namespace tflite { namespace label_image { -uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, - int* channels, Settings* s); +std::vector read_bmp(const std::string& input_bmp_name, int* width, + int* height, int* channels, Settings* s); template void resize(T* out, uint8_t* in, int image_height, int image_width, diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 966fcd2a31fd4d4ff2c3e91633550a8effa81ee8..86d7d1cc4a625243791d5e7d5b746526a58efb6d 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -138,8 +138,8 @@ void RunInference(Settings* s) { int image_width = 224; int image_height = 224; int image_channels = 3; - uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height, - &image_channels, s); + std::vector in = read_bmp(s->input_bmp_name, &image_width, + &image_height, &image_channels, s); int input = interpreter->inputs()[0]; if (s->verbose) LOG(INFO) << "input: " << input << "\n"; @@ -168,12 +168,12 @@ void RunInference(Settings* s) { switch (interpreter->tensor(input)->type) { case kTfLiteFloat32: s->input_floating = true; - resize(interpreter->typed_tensor(input), in, image_height, - image_width, image_channels, wanted_height, wanted_width, - wanted_channels, s); + resize(interpreter->typed_tensor(input), in.data(), + image_height, image_width, image_channels, wanted_height, + wanted_width, wanted_channels, s); break; case kTfLiteUInt8: - resize(interpreter->typed_tensor(input), in, + resize(interpreter->typed_tensor(input), in.data(), image_height, image_width, image_channels, wanted_height, wanted_width, wanted_channels, s); break; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc index ce35483f76e8f40ced79e1ee30774c62d0eba94e..de7de21f7741d3d46cb96e793e8bc4bfb21384fe 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc @@ -27,20 +27,20 @@ namespace label_image { TEST(LabelImageTest, GraceHopper) { std::string lena_file = - "tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp"; + "tensorflow/contrib/lite/examples/label_image/testdata/" + "grace_hopper.bmp"; int height, width, channels; Settings s; - uint8_t *data; - - data = read_bmp(lena_file, &width, &height, &channels, &s); + std::vector input = + read_bmp(lena_file, &width, &height, &channels, &s); ASSERT_EQ(height, 606); ASSERT_EQ(width, 517); ASSERT_EQ(channels, 3); - uint8_t *out = new uint8_t[606 * 517 * 3]; - downsize(out, data, 606, 517, 3, 214, 214, 3, &s); - ASSERT_EQ(out[0], 0x15); - ASSERT_EQ(out[214 * 214 * 3 - 1], 0x12); + std::vector output(606 * 517 * 3); + resize(output.data(), input.data(), 606, 517, 3, 214, 214, 3, &s); + ASSERT_EQ(output[0], 0x15); + ASSERT_EQ(output[214 * 214 * 3 - 1], 0x11); } TEST(LabelImageTest, GetTopN) { diff --git a/tensorflow/contrib/lite/examples/minimal/BUILD b/tensorflow/contrib/lite/examples/minimal/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b403628d6c457ce3fb67eac3675fd7bb9187deab --- /dev/null +++ b/tensorflow/contrib/lite/examples/minimal/BUILD @@ -0,0 +1,27 @@ +# Description: +# TensorFlow Lite minimal example. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") + +tf_cc_binary( + name = "minimal", + srcs = [ + "minimal.cc", + ], + linkopts = tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + ], + "//conditions:default": [], + }), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc index 106e3b027055b67092f653c6bcdc4827b56bdbaa..8b65cde7b79fde19280ad778ea874c64b01d169a 100644 --- a/tensorflow/contrib/lite/examples/minimal/minimal.cc +++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/lite/model.h" +#include #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" -#include +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/optional_debug_tools.h" // This is an example that is minimal to read a model // from disk and perform inference. There is no data being loaded @@ -29,23 +30,22 @@ limitations under the License. using namespace tflite; -#define TFLITE_MINIMAL_CHECK(x) \ - if(!(x)) { \ - fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ - exit(1); \ +#define TFLITE_MINIMAL_CHECK(x) \ + if (!(x)) { \ + fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ + exit(1); \ } - -int main(int argc, char *argv[]) { +int main(int argc, char* argv[]) { if(argc != 2) { - fprintf(stderr, "Usage: %s \n"); + fprintf(stderr, "minimal \n"); return 1; } const char* filename = argv[1]; // Load model - std::unique_ptr model - = tflite::FlatBufferModel::BuildFromFile(filename); + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromFile(filename); TFLITE_MINIMAL_CHECK(model != nullptr); // Build the interpreter @@ -57,12 +57,16 @@ int main(int argc, char *argv[]) { // Allocate tensor buffers. TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); + printf("=== Pre-invoke Interpreter State ===\n"); + tflite::PrintInterpreterState(interpreter.get()); // Fill input buffers // TODO(user): Insert code to fill input tensors // Run inference TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); + printf("\n\n=== Post-invoke Interpreter State ===\n"); + tflite::PrintInterpreterState(interpreter.get()); // Read output buffers // TODO(user): Insert getting data out code. diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md new file mode 100644 index 0000000000000000000000000000000000000000..bd2f797e6c5b05f52bec9fc34f1b8011aca70330 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md @@ -0,0 +1,206 @@ +# TensorFlow Lite Ops Versioning + +This document describes TensorFlow Lite's op versioning schema. Op +versioning enables developers to add new functionalities and parameters into +existing ops. In addition, it guarantees the following: + +* Backward compatibility: New TensorFlow Lite implementation should + handle an old model file. +* Forward compatibility: Old TensorFlow Lite implementation should + handle a new model file produced by new version of TOCO, as long as no new + features are used. +* Forward in-compatibility detection: If an old TensorFlow Lite implementation + reads a new model that contains a new version of an op which isn't + supported, it should report the error. + +## Example: Adding Dilation into Convolution + +The remainder of this document explains op versioning in TFLite by showing how +to add dilation parameters to the convolution operation. + +Knowledge of dilation is not required to understand this document. Note that: + +* 2 new integer parameters will be added: `dilation_width_factor` and + `dilation_height_factor`. +* Old convolution kernels that don't support dilation are equivalent to + setting the dilation factors to 1. + +### Change FlatBuffer Schema + +To add new parameters into an op, change the options table in +`lite/schema/schema.fbs`. + +For example, the options table of convolution looks like this: + +``` +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} +``` + +When adding new parameters: + +* Add comments indicating which parameters are supported by which version. +* When the new implementation gets the default values for newly added + parameters, it should work exactly the same as the old implementation. + +The table will be like this after the new parameters are added: + +``` +table Conv2DOptions { + // Parameters supported by version 1: + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + + // Parameters supported by version 2: + dilation_width_factor:int = 1; + dilation_height_factor:int = 1; +} +``` + +### Change C Structures and Kernel Implementation + +In TensorFlow Lite, the kernel implementation is decoupled from +FlatBuffer definition. The kernels read the parameter from C structures defined +in `lite/builtin_op_data.h`. + +The original convolution parameter is as follows: + +``` +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; +} TfLiteConvParams; +``` + +As with the FlatBuffer schema, add comments indicating which parameters are +supported starting from which version. The result is seen below: + +``` +typedef struct { + // Parameters supported by version 1: TfLitePadding padding; int + stride_width; + int stride_height; + TfLiteFusedActivation activation; + + // Parameters supported by version 2: + int dilation_width_factor; + int dilation_height_factor; +} TfLiteConvParams; +``` + +Please also change the kernel implementation to read the newly added parameters +from the C structures. The details are omitted here. + +### Change the FlatBuffer Reading Code + +The logic to read FlatBuffer and produce C structure is in `lite/model.cc`. + +Update the file to handle the new parameters, as shown below: + +``` +case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + params->dilation_width_factor = conv_params->dilation_width_factor(); + params->dilation_height_factor = conv_params->dilation_height_factor(); + } + *builtin_data = reinterpret_cast(params); + break; +} +``` + +It's not required to check the op version here. When the new implementation +reads an old model file where dilation factors are missing, it will use 1 as +the default value, and the new kernel will work consistently with the old +kernel. + +### Change Kernel Registration + +The MutableOpResolver (defined in `lite/op_resolver.h`) provides a few functions +to register op kernels. The minimum and maximum version are 1 by default: + +``` +void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); +void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); +``` + +The built-in ops are registered in `lite/kernels/register.cc`. In this example, +we implemented a new op kernel which can handle `Conv2D` version 1 and 2, so we +need to change this line: + +``` +AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D()); +``` + +to: + +``` +AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), 1, 2); +``` + +### Change TOCO TFLite exporter + +The last step is to make TOCO populate the minimum version that's required to +execute the op. In this example, it means: + +* Populate version=1 when dilation factors are all 1. +* Populate version=2 otherwise. + +To do this, you need to override `GetVersion` function for the operator class in +`lite/toco/tflite/operator.cc`. + +For ops with only one version, the `GetVersion` function is defined as: + +``` +int GetVersion(const Operator& op) const override { return 1; } +``` + +When supporting multiple versions, check the parameters and determine the +version for the op, as shown in the following example: + +``` +int GetVersion(const Operator& op) const override { + const auto& conv_op = static_cast(op); + if (conv_op.dilation_width_factor != 1 || + conv_op.dilation_height_factor != 1) { + return 2; + } + return 1; +} +``` + +### Delegation Implementation + +TensorFlow Lite provides a delegation API which enables delegating ops to +hardware backends. In Delegate's `Prepare` function, check if the version +is supported for every node in Delegation code. + +``` +const int kMinVersion = 1; +TfLiteNode* node; +TfLiteRegistration; +context->GetNodeAndRegistration(context, node_index, &node, ®istration); + +if (registration->version > kMinVersion) { + // Reject the node if the version isn't supported. +} +``` + +This is required even if the delegation only supports version 1 ops, so the +delegation can detect incompatibility when getting a higher version op. + diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index b2f6444e9e1dc37b5195d857e9620bf725e3a5b2..45104c141945a451351257e9bdbb43c0ad328258 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -95,11 +95,7 @@ Here is a list of TensorFlow operations that are usually removed from the graph: * [tf.divide](https://www.tensorflow.org/api_docs/python/tf/divide) * [tf.fake_quant_with_min_max_args](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args) * [tf.fake_quant_with_min_max_vars](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars) -* [tf.greater](https://www.tensorflow.org/api_docs/python/tf/greater) -* [tf.greater_equal](https://www.tensorflow.org/api_docs/python/tf/greater_equal) * [tf.identity](https://www.tensorflow.org/api_docs/python/tf/identity) -* [tf.less](https://www.tensorflow.org/api_docs/python/tf/less) -* [tf.less_equal](https://www.tensorflow.org/api_docs/python/tf/less_equal) * [tf.maximum](https://www.tensorflow.org/api_docs/python/tf/maximum) * [tf.minimum](https://www.tensorflow.org/api_docs/python/tf/minimum) * [tf.multiply](https://www.tensorflow.org/api_docs/python/tf/multiply) @@ -257,6 +253,19 @@ Options { } ``` +**EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + equal to the corresponding element of the second tensor. +} +``` + **EXP** ``` @@ -420,6 +429,17 @@ Outputs { } ``` +**LOG** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a tensor equivalent to log(input) +} +``` + **LOG_SOFTMAX** ``` @@ -503,6 +523,19 @@ Options { } ``` +**NOT_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is not + equal to the corresponding element of the second tensor. +} +``` + **RELU** ``` @@ -551,6 +584,31 @@ Options { } ``` +**RSQRT** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: result of computing element-wise reciprocal square root of the input tensor +} +``` + +**SHAPE** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a 1D tensor representing the shape of the input tensor +} +Options { + out_type: the output type of the op (int32 or int64). Defaults to int32. +} +``` + **SLICE** ``` @@ -637,6 +695,17 @@ Options { } ``` +**SQRT** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: result of computing element-wise square root of the input tensor +} +``` + **SQUEEZE** ``` diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h index 313af5fb7574b42bcdd53b4baad06e4ccfb34053..77268d7aebe9ebfb33b9f35b319d34e6de8324ee 100644 --- a/tensorflow/contrib/lite/graph_info.h +++ b/tensorflow/contrib/lite/graph_info.h @@ -46,6 +46,9 @@ class GraphInfo { // Returns the indices of the output tensors. virtual const std::vector& outputs() const = 0; + + // Returns the indices of the variable tensors. + virtual const std::vector& variables() const = 0; }; // Represents a subgraph of a TensorFlow Lite graph. diff --git a/tensorflow/contrib/lite/graph_info_test.cc b/tensorflow/contrib/lite/graph_info_test.cc index ea38b43993fef71c6820c7a978351d92d5420287..89a8f36b416b5dec54c1e374cdcdae3ab9ab0cde 100644 --- a/tensorflow/contrib/lite/graph_info_test.cc +++ b/tensorflow/contrib/lite/graph_info_test.cc @@ -45,6 +45,7 @@ class SimpleTestGraph : public GraphInfo { TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; } const std::vector& inputs() const override { return inputs_; } const std::vector& outputs() const override { return outputs_; } + const std::vector& variables() const override { return variables_; } void AddNode(const std::vector& inputs, const std::vector& outputs) { @@ -67,6 +68,7 @@ class SimpleTestGraph : public GraphInfo { std::vector tensors_; std::vector inputs_; std::vector outputs_; + std::vector variables_; }; // Partition a graph to generate a list of subgraphs. This wraps the API call diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index ebb0aedc2001a86b7fcff67ef8703b5e4a845818..57b2c0f32b6c07083bb88aa9b81fcd7f71dbc672 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -82,6 +82,9 @@ class InterpreterInfo : public GraphInfo { const std::vector& outputs() const override { return interpreter_->outputs(); } + const std::vector& variables() const override { + return interpreter_->variables(); + } public: Interpreter* interpreter_; @@ -302,6 +305,13 @@ TfLiteStatus Interpreter::SetOutputs(std::vector outputs) { return kTfLiteOk; } +TfLiteStatus Interpreter::SetVariables(std::vector variables) { + TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("variables", variables.data(), + variables.size())); + variables_ = std::move(variables); + return kTfLiteOk; +} + TfLiteStatus Interpreter::CheckTensorIndices(const char* label, const int* indices, int length) { // Making sure kOptionalTensor is not re-defined to something other than -1. @@ -334,6 +344,9 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, case kTfLiteFloat32: *bytes = sizeof(float) * count; break; + case kTfLiteInt16: + *bytes = sizeof(int16_t) * count; + break; case kTfLiteInt32: *bytes = sizeof(int32_t) * count; break; @@ -347,9 +360,9 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, *bytes = sizeof(bool) * count; break; default: - ReportError( - &context_, - "Only float32, int32, int64, uint8, bool supported currently."); + ReportError(&context_, + "Only float32, int16, int32, int64, uint8, bool supported " + "currently."); return kTfLiteError; } return kTfLiteOk; @@ -367,6 +380,7 @@ TfLiteStatus Interpreter::AllocateTensors() { } TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + if (state_ == kStateUninvokable) { state_ = kStateInvokable; } @@ -375,6 +389,25 @@ TfLiteStatus Interpreter::AllocateTensors() { return kTfLiteOk; } +// TODO(ycling): Consider to provide other functions to initialize variable +// tensors to non-zero values. +TfLiteStatus Interpreter::ResetVariableTensorsToZero() { + for (auto& tensor : tensors_) { + if (!tensor.is_variable) { + continue; + } + + // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be + // allocated after the initial `PrepareOpsAndTensors()` is called. + TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type, + kTfLiteArenaRwPersistent); + TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr); + + memset(tensor.data.raw, 0, tensor.bytes); + } + return kTfLiteOk; +} + TfLiteStatus Interpreter::AddNodeWithParameters( const std::vector& inputs, const std::vector& outputs, const char* init_data, size_t init_data_size, void* builtin_data, @@ -572,9 +605,17 @@ TfLiteStatus Interpreter::Invoke() { } EnsureTensorsVectorCapacity(); + tensor_resized_since_op_invoke_ = false; if (OpInvoke(registration, &node) == kTfLiteError) { status = kTfLiteError; } + + // Force execution prep for downstream ops if the latest op triggered the + // resize of a dynamic tensor. + if (tensor_resized_since_op_invoke_ && + HasDynamicTensor(context_, node.outputs)) { + next_execution_plan_index_to_prepare_ = execution_plan_index + 1; + } } if (!allow_buffer_handle_output_) { @@ -687,7 +728,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( state_ = kStateUninvokable; TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, const_cast(buffer), bytes, - kTfLiteMmapRo, allocation, &tensor); + kTfLiteMmapRo, allocation, false, &tensor); } return kTfLiteOk; } @@ -698,7 +739,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( // to Interpreter. TfLiteStatus Interpreter::SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization) { + const int* dims, TfLiteQuantizationParams quantization, bool is_variable) { if (state_ == kStateInvokableAndImmutable) { ReportError( &context_, @@ -716,11 +757,23 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite( TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims, rank, &required_bytes)); } + + TfLiteAllocationType allocation_type = kTfLiteArenaRw; + if (type == kTfLiteString) { + if (is_variable) { + // We don't have a real use case for string variable tensor. + ReportError(&context_, "String variable tensor isn't supported."); + return kTfLiteError; + } + allocation_type = kTfLiteDynamic; + } else if (is_variable) { + allocation_type = kTfLiteArenaRwPersistent; + } + TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims), quantization, - /*buffer=*/nullptr, required_bytes, - type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw, - nullptr, &context_.tensors[tensor_index]); + /*buffer=*/nullptr, required_bytes, allocation_type, + nullptr, is_variable, &context_.tensors[tensor_index]); return kTfLiteOk; } @@ -736,7 +789,10 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size) { // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. if (tensor->allocation_type == kTfLiteArenaRw || - tensor->allocation_type == kTfLiteDynamic) { + tensor->allocation_type == kTfLiteDynamic || + tensor->allocation_type == kTfLiteArenaRwPersistent) { + tensor_resized_since_op_invoke_ |= + TfLiteIntArrayEqual(tensor->dims, new_size) == 0; if (tensor->type != kTfLiteString) { size_t bytesRequired; TfLiteStatus status = BytesRequired(tensor->type, new_size->data, diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 7315d8360680ca0d3c405dc80b593762275815ee..6b36bfc11f2d15368bbfb08708ebe2e88a597d62 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -39,6 +39,10 @@ constexpr TfLiteType typeToTfLiteType() { return kTfLiteInt32; } template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteInt16; +} +template <> constexpr TfLiteType typeToTfLiteType() { return kTfLiteInt64; } @@ -118,6 +122,11 @@ class Interpreter { // interpreter. TfLiteStatus SetOutputs(std::vector outputs); + // Provide a list of tensor indexes that are variable tensors. + // Each index is bound check and this modifies the consistent_ flag of the + // interpreter. + TfLiteStatus SetVariables(std::vector variables); + // Adds a node with the given parameters and returns the index of the new // node in `node_index` (optionally). Interpreter will take ownership of // `builtin_data` and destroy it with `free`. Ownership of 'init_data' @@ -160,13 +169,15 @@ class Interpreter { // to Interpreter. inline TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, - const std::vector& dims, TfLiteQuantizationParams quantization) { + const std::vector& dims, TfLiteQuantizationParams quantization, + bool is_variable = false) { return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(), - dims.data(), quantization); + dims.data(), quantization, is_variable); } TfLiteStatus SetTensorParametersReadWrite( int tensor_index, TfLiteType type, const char* name, const size_t rank, - const int* dims, TfLiteQuantizationParams quantization); + const int* dims, TfLiteQuantizationParams quantization, + bool is_variable = false); // Functions to access tensor data @@ -182,6 +193,9 @@ class Interpreter { // Read only access to list of outputs. const std::vector& outputs() const { return outputs_; } + // Read only access to list of variable tensors. + const std::vector& variables() const { return variables_; } + // Return the name of a given output. The given index must be between 0 and // outputs().size(). const char* GetOutputName(int index) const { @@ -379,6 +393,10 @@ class Interpreter { allow_buffer_handle_output_ = allow_buffer_handle_output; } + // Reset all variable tensors to zero. + // WARNING: This is an experimental API and subject to change. + TfLiteStatus ResetVariableTensorsToZero(); + private: // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. @@ -541,6 +559,9 @@ class Interpreter { // interpreter. std::vector outputs_; + // Array of indices representing the tensors that are variable tensors. + std::vector variables_; + // The error reporter delegate that tflite will forward queries errors to. ErrorReporter* error_reporter_; @@ -572,6 +593,11 @@ class Interpreter { bool allow_buffer_handle_output_ = false; + // Tracking bit for whether a tensor was resized in the course of an op + // invocation. This is a useful hint to ensure that dynamic tensor outputs + // trigger downstream reallocation after op invocation. + bool tensor_resized_since_op_invoke_ = false; + // Profiler for this interpreter instance. profiling::Profiler* profiler_; }; diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 453c1ada1cf6263be14a3b170f209e3a30580cc3..21cdf87d1e421868d1b62c5e23c2481cfbb4c989 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -23,6 +23,12 @@ limitations under the License. #include "tensorflow/contrib/lite/testing/util.h" namespace tflite { +namespace ops { +namespace builtin { +TfLiteRegistration* Register_PADV2(); +TfLiteRegistration* Register_NEG(); +} // namespace builtin +} // namespace ops namespace { // Make an interpreter that has no tensors and no nodes @@ -106,10 +112,9 @@ TEST(BasicInterpreter, CheckAllocate) { TfLiteType type; size_t size; } cases[] = { - {kTfLiteFloat32, sizeof(float)}, - {kTfLiteInt32, sizeof(int32_t)}, - {kTfLiteUInt8, sizeof(uint8_t)}, - {kTfLiteInt64, sizeof(int64_t)}, + {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)}, + {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)}, + {kTfLiteInt16, sizeof(int16_t)}, }; for (auto test : cases) { @@ -134,6 +139,7 @@ TEST(BasicInterpreter, CheckResize) { const int32_t int32s[] = {-3, -4}; const uint8_t uint8s[] = {3, 4}; const int64_t int64s[] = {6, -7}; + const int16_t int16s[] = {8, -9}; struct { TfLiteType type; @@ -144,6 +150,7 @@ TEST(BasicInterpreter, CheckResize) { {kTfLiteInt32, sizeof(int32_t), reinterpret_cast(int32s)}, {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast(uint8s)}, {kTfLiteInt64, sizeof(int64_t), reinterpret_cast(int64s)}, + {kTfLiteInt16, sizeof(int16_t), reinterpret_cast(int16s)}, }; for (auto test : cases) { @@ -179,10 +186,8 @@ TEST(BasicInterpreter, CheckAlignment) { struct { TfLiteType type; } cases[] = { - {kTfLiteFloat32}, - {kTfLiteInt32}, - {kTfLiteUInt8}, - {kTfLiteInt64}, + {kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8}, + {kTfLiteInt64}, {kTfLiteInt16}, }; for (auto test : cases) { @@ -211,7 +216,7 @@ TEST(BasicInterpreter, CheckArenaAllocation) { TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; std::vector sizes{2048, 4096, 1023, 2047, 1021, - 2047, 1023, 2046, 1021, 2048}; + 2047, 1023, 2046, 0, 2048}; for (int i = 0; i < sizes.size(); ++i) { interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]}, quant); @@ -228,6 +233,7 @@ TEST(BasicInterpreter, CheckArenaAllocation) { ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); + ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr); ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); @@ -314,6 +320,18 @@ TEST(BasicInterpreter, ResizingTensors) { EXPECT_EQ(tensor->bytes, 8 * sizeof(float)); ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(t, {}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 1 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {0}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 0); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 0}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 0); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + // TODO(ahentz): We shouldn't have to force reallocation, but // ResizeInputTensor doesn't realloc dynamic tensors. Also note that // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op. @@ -603,6 +621,59 @@ TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) { EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteError); } +TEST(BasicInterpreter, DynamicTensorsResizeDescendants) { + // Assemble a graph with a node that has dynamically sized output (via the + // pad op), followed by a node with a standard element-wise op (negate). + Interpreter interpreter; + interpreter.AddTensors(4); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({3}); + TfLiteQuantizationParams quant; + interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {2, 2, 1, 1}, + quant); + interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "", {4, 2}, quant); + interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {}, quant); + interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {}, quant); + + TfLiteRegistration* pad_op = tflite::ops::builtin::Register_PADV2(); + TfLiteRegistration* neg_op = tflite::ops::builtin::Register_NEG(); + interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr, pad_op); + interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, neg_op); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // Configure [[2,2],[4,4]] padding and execute the graph. + interpreter.typed_tensor(1)[0] = 2; + interpreter.typed_tensor(1)[1] = 2; + interpreter.typed_tensor(1)[2] = 2; + interpreter.typed_tensor(1)[3] = 2; + interpreter.typed_tensor(1)[4] = 0; + interpreter.typed_tensor(1)[5] = 0; + interpreter.typed_tensor(1)[6] = 0; + interpreter.typed_tensor(1)[7] = 0; + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Both the output and intermediate tensor sizes should reflect the output + // from the dynamic pad operation. + ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 6 * 6); + ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 6 * 6); + + // Now configure [[4,4],[6,6]] padding and execute the graph. + interpreter.typed_tensor(1)[0] = 4; + interpreter.typed_tensor(1)[1] = 4; + interpreter.typed_tensor(1)[2] = 6; + interpreter.typed_tensor(1)[3] = 6; + interpreter.typed_tensor(1)[4] = 0; + interpreter.typed_tensor(1)[5] = 0; + interpreter.typed_tensor(1)[6] = 0; + interpreter.typed_tensor(1)[7] = 0; + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + // Again, the output and intermediate tensor sizes should reflect the *new* + // resize from the latest pad operation. + ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 10 * 14); + ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 10 * 14); +} + TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity), diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md index 2e818f728ef208d30b0eeb27ffd7e3fa0c7c1a2d..e3cea19e1683ac2680521bce66d1328e4b2caf1c 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -1,5 +1,14 @@ # TF Lite Android App +## Building in Android Studio with TensorFlow Lite AAR from JCenter. +The build.gradle is configured to use TensorFlow Lite's nightly build. + +If you see a build error related to compatibility with Tensorflow Lite's Java API (example: method X is +undefined for type Interpreter), there has likely been a backwards compatible +change to the API. You will need to pull new app code that's compatible with the +nightly build and may need to first wait a few days for our external and internal +code to merge. + ## Building from Source with Bazel 1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel): diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle index b76eaad8bb91224805d16b3d6f7c3274c9feb90c..44ea2dcd908644bcfc637f71573ce722adaf6935 100644 --- a/tensorflow/contrib/lite/java/demo/app/build.gradle +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -52,7 +52,43 @@ dependencies { compile 'com.android.support:support-annotations:25.3.1' compile 'com.android.support:support-v13:25.2.0' - compile 'org.tensorflow:tensorflow-lite:+' + compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly' testCompile 'junit:junit:4.12' } + +def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" +def localCache = "build/intermediates/mobilenet_v1_224_android_quant_2017_11_08.zip" +def targetFolder = "src/main/assets" + +task downloadModel(type: DownloadUrlTask) { + doFirst { + println "Downloading ${modelDownloadUrl}" + } + sourceUrl = "${modelDownloadUrl}" + target = file("${localCache}") +} + +task unzipModel(type: Copy, dependsOn: 'downloadModel') { + doFirst { + println "Unzipping ${localCache}" + } + from zipTree("${localCache}") + into "${targetFolder}" +} + +// Ensure the model file is downloaded and extracted before every build +preBuild.dependsOn unzipModel + +class DownloadUrlTask extends DefaultTask { + @Input + String sourceUrl + + @OutputFile + File target + + @TaskAction + void download() { + ant.get(src: sourceUrl, dest: target) + } +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD index d6fbef9cc938993b283103984307ab51e609dd6e..220d6c2159b56f6349e93132418fa0f6c69d1ab3 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:private"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD index 362d93636f72205ddcda6d97fa9fae376ff211f1..f232b00045cf1df6a31ada80af4cc5885a4c0099 100644 --- a/tensorflow/contrib/lite/java/ovic/BUILD +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -1,6 +1,8 @@ # Description: # OVIC Benchmarker Java API. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD index 83974f4b337baedebaf9c9ffc0a03501418a3e36..a8d751ade26adc358e130138381eab9956f2d848 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + # Sample app for OVIC benchmarking. licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 2ae6c516b03ef4292667bbd944c73d2eeaf82db3..80de88b6a1cd75b033e116f76f5612ee66e48f03 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -311,8 +311,30 @@ final class NativeInterpreterWrapper implements AutoCloseable { return DataType.fromNumber(type).toStringName(); } + /** + * Gets the quantization zero point of an output. + * + * @throws IllegalArgumentExeption if the output index is invalid. + */ + int getOutputQuantizationZeroPoint(int index) { + return getOutputQuantizationZeroPoint(interpreterHandle, index); + } + + /** + * Gets the quantization scale of an output. + * + * @throws IllegalArgumentExeption if the output index is invalid. + */ + float getOutputQuantizationScale(int index) { + return getOutputQuantizationScale(interpreterHandle, index); + } + private static native int getOutputDataType(long interpreterHandle, int outputIdx); + private static native int getOutputQuantizationZeroPoint(long interpreterHandle, int outputIdx); + + private static native float getOutputQuantizationScale(long interpreterHandle, int outputIdx); + private static final int ERROR_BUFFER_SIZE = 512; private long errorHandle; diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index 1fb6997fb9ba180e9a3f3a89a6d177086440c0d7..31f7b58fbc30cab9e6cb813094ea4b2627ba5cba 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -561,6 +561,38 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( return static_cast(type); } +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 0; + const int idx = static_cast(output_idx); + if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { + throwException(env, kIllegalArgumentException, + "Failed to get %d-th output out of %d outputs", output_idx, + interpreter->outputs().size()); + return 0; + } + TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); + return static_cast(target->params.zero_point); +} + +JNIEXPORT jfloat JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return 1.0f; + const int idx = static_cast(output_idx); + if (output_idx < 0 || output_idx >= interpreter->outputs().size()) { + throwException(env, kIllegalArgumentException, + "Failed to get %d-th output out of %d outputs", output_idx, + interpreter->outputs().size()); + return 1.0f; + } + TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]); + return static_cast(target->params.scale); +} + JNIEXPORT jboolean JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index eaa765cb343e9764bd0ef018d636a76f4b8a13e4..128ece49811a112684dac7b36810e920eeeb7351 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -152,6 +152,28 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType( JNIEnv* env, jclass clazz, jlong handle, jint output_idx); +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JI)I + * + * Gets output quantization zero point. + */ +JNIEXPORT jint JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JI)F + * + * Gets output quantization scale. + */ +JNIEXPORT jfloat JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale( + JNIEnv* env, jclass clazz, jlong handle, jint output_idx); + /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc index 005dca0253d2c30d56a15adf6e2b371d43f50945..9e9387da86ebde7d711a7ce967461e370c95bc3e 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -43,31 +43,27 @@ size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, } switch (type) { case kTfLiteFloat32: { - jfloatArray a = static_cast(array); - jfloat* values = env->GetFloatArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseFloatArrayElements(a, values, JNI_ABORT); + jfloatArray float_array = static_cast(array); + jfloat* float_dst = static_cast(dst); + env->GetFloatArrayRegion(float_array, 0, num_elements, float_dst); return to_copy; } case kTfLiteInt32: { - jintArray a = static_cast(array); - jint* values = env->GetIntArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseIntArrayElements(a, values, JNI_ABORT); + jintArray int_array = static_cast(array); + jint* int_dst = static_cast(dst); + env->GetIntArrayRegion(int_array, 0, num_elements, int_dst); return to_copy; } case kTfLiteInt64: { - jlongArray a = static_cast(array); - jlong* values = env->GetLongArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseLongArrayElements(a, values, JNI_ABORT); + jlongArray long_array = static_cast(array); + jlong* long_dst = static_cast(dst); + env->GetLongArrayRegion(long_array, 0, num_elements, long_dst); return to_copy; } case kTfLiteUInt8: { - jbyteArray a = static_cast(array); - jbyte* values = env->GetByteArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseByteArrayElements(a, values, JNI_ABORT); + jbyteArray byte_array = static_cast(array); + jbyte* byte_dst = static_cast(dst); + env->GetByteArrayRegion(byte_array, 0, num_elements, byte_dst); return to_copy; } default: { diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 7c00d3196fd001a288d77d4e01f0b30978d72afe..9e41cb132d8386748e24c46d846e04f158d8b4c6 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -41,6 +41,9 @@ public final class NativeInterpreterWrapperTest { private static final String BYTE_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/uint8.bin"; + private static final String QUANTIZED_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/quantized.bin"; + private static final String INVALID_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; @@ -536,4 +539,16 @@ public final class NativeInterpreterWrapperTest { assertThat(wrapper.getOutputDataType(0)).contains("byte"); wrapper.close(); } + + @Test + public void testGetOutputQuantizationParams() { + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) { + assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(0); + assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.0f); + } + try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(QUANTIZED_MODEL_PATH)) { + assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(127); + assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.25f); + } + } } diff --git a/tensorflow/contrib/lite/java/src/testdata/quantized.bin b/tensorflow/contrib/lite/java/src/testdata/quantized.bin new file mode 100644 index 0000000000000000000000000000000000000000..4062088cdf717e8752490de5c9acff35fd6af54f Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/quantized.bin differ diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD index b524246d436858bbf506809a38cead2897f78d93..af1d99ef41e6413d8ef2c6f478aaa8f9e3931ff8 100644 --- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -1,6 +1,8 @@ # Description: # Internal helper function to test TF Lite API. +load("@build_bazel_rules_android//android:rules.bzl", "android_library") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index cf5d0b4ce9cb3c516c185f31fea12db70a2c3bdb..a77897a173fc1bd9ceb63e6918ebbfb69f6d6af1 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -142,6 +142,7 @@ cc_library( "conv.cc", "depthwise_conv.cc", "dequantize.cc", + "detection_postprocess.cc", "div.cc", "elementwise.cc", "embedding_lookup.cc", @@ -157,16 +158,17 @@ cc_library( "lsh_projection.cc", "lstm.cc", "maximum_minimum.cc", - "mean.cc", "mfcc.cc", "mul.cc", "neg.cc", "pad.cc", "pooling.cc", + "reduce.cc", "register.cc", "reshape.cc", "resize_bilinear.cc", "select.cc", + "shape.cc", "skip_gram.cc", "slice.cc", "space_to_batch_nd.cc", @@ -246,6 +248,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "detection_postprocess_test", + size = "small", + srcs = ["detection_postprocess_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + tf_cc_test( name = "activations_test", size = "small", @@ -554,9 +570,9 @@ tf_cc_test( ) tf_cc_test( - name = "mean_test", + name = "reduce_test", size = "small", - srcs = ["mean_test.cc"], + srcs = ["reduce_test.cc"], tags = ["tflite_not_portable_ios"], deps = [ ":builtin_ops", @@ -979,6 +995,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "shape_test", + size = "small", + srcs = ["shape_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index add36b46c0b8a4deab1e842d50194c8b99a3a20c..99f81c4a8a78ab0b2a24955d77f25ed09da13b84 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -84,6 +84,38 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) { &data->input_left_shift); data->input_range_radius = CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } else if (input->type == kTfLiteInt16) { + static constexpr int kInputIntegerBits = 3; + static constexpr int kOutputFractionalBits = 15; + + // These operators are implemented in fixed-point arithmetic, + // which intrinsically wants symmetric ranges (zero_point==0) + // and power-of-two scales (power-of-two is abbreviated below as POT). + // While more general support would be possible by means of rescaling, + // that would add some overhead and some loss of accuracy and wouldn't + // be used at the moment as current quantized LSTM applications are + // happy with symmetric, power-of-two-scales quantization. So we just + // implement that narrow case only for now. + + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input_scale_log2_rounded; + TF_LITE_ENSURE(context, + CheckedLog2(input->params.scale, &input_scale_log2_rounded)); + + int output_scale_log2_rounded; + TF_LITE_ENSURE( + context, CheckedLog2(output->params.scale, &output_scale_log2_rounded)); + TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded, + -kOutputFractionalBits); + + data->input_left_shift = + (15 - kInputIntegerBits) + input_scale_log2_rounded; + // Support for shifts is limited until we have a parameterized version of + // SaturatingRoundingMultiplyByPOT(). + TF_LITE_ENSURE(context, data->input_left_shift >= 0); + TF_LITE_ENSURE(context, data->input_left_shift <= 1); } return context->ResizeTensor(context, output, @@ -114,6 +146,30 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { &data->input_left_shift); data->input_range_radius = CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } else if (input->type == kTfLiteInt16) { + static constexpr int kInputIntegerBits = 3; + static constexpr int kOutputFractionalBits = 15; + + // See comments in TanhPrepare about requiring zero_point==0 + // and a power-of-two ("POT") scale. + + TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input_scale_log2_rounded; + TF_LITE_ENSURE(context, + CheckedLog2(input->params.scale, &input_scale_log2_rounded)); + + int output_scale_log2_rounded; + TF_LITE_ENSURE( + context, CheckedLog2(output->params.scale, &output_scale_log2_rounded)); + TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded, + -kOutputFractionalBits); + + data->input_left_shift = + (15 - kInputIntegerBits) + input_scale_log2_rounded; + // The int16 logistic implementation does not support shifting of the input. + TF_LITE_ENSURE_EQ(context, data->input_left_shift, 0); } return context->ResizeTensor(context, output, @@ -250,12 +306,19 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { for (; in < in_end; in++, out++) *out = std::tanh(*in); return kTfLiteOk; } break; + case kTfLiteInt16: { + optimized_ops::Tanh(GetTensorData(input), GetTensorShape(input), + data->input_left_shift, + GetTensorData(output), + GetTensorShape(output)); + return kTfLiteOk; + } break; case kTfLiteUInt8: { - optimized_ops::Tanh(GetTensorData(input), GetTensorDims(input), + optimized_ops::Tanh(GetTensorData(input), GetTensorShape(input), input->params.zero_point, data->input_range_radius, data->input_multiplier, data->input_left_shift, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); return kTfLiteOk; } break; default: @@ -280,12 +343,18 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in)); break; } + case kTfLiteInt16: { + optimized_ops::Logistic( + GetTensorData(input), GetTensorShape(input), + GetTensorData(output), GetTensorShape(output)); + break; + } case kTfLiteUInt8: { optimized_ops::Logistic( - GetTensorData(input), GetTensorDims(input), + GetTensorData(input), GetTensorShape(input), input->params.zero_point, data->input_range_radius, data->input_multiplier, data->input_left_shift, - GetTensorData(output), GetTensorDims(output)); + GetTensorData(output), GetTensorShape(output)); break; } default: @@ -341,26 +410,26 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, const int batch_size = input->dims->data[0]; const int input_size = input->dims->data[1]; optimized_ops::Softmax(GetTensorData(input), - GetTensorDims({batch_size, 1, 1, input_size}), + GetTensorShape({batch_size, 1, 1, input_size}), data->input_multiplier, data->input_left_shift, data->diff_min, GetTensorData(output), - GetTensorDims({batch_size, 1, 1, input_size})); + GetTensorShape({batch_size, 1, 1, input_size})); } // Takes a 4D tensor and perform softmax along the forth dimension. void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params) { - optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + optimized_ops::Softmax(GetTensorData(input), GetTensorShape(input), params->beta, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); } void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params, OpData* data) { - optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + optimized_ops::Softmax(GetTensorData(input), GetTensorShape(input), data->input_multiplier, data->input_left_shift, data->diff_min, GetTensorData(output), - GetTensorDims(output)); + GetTensorShape(output)); } TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { @@ -415,8 +484,8 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { switch (input->type) { case kTfLiteFloat32: optimized_ops::LogSoftmax( - GetTensorData(input), GetTensorDims(input), - GetTensorData(output), GetTensorDims(output)); + GetTensorData(input), GetTensorShape(input), + GetTensorData(output), GetTensorShape(output)); return kTfLiteOk; default: context->ReportError(context, "Only float32 supported currently., got %d", diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index 50a84edd475c8051a563cf8ed9fc03099829b786..587e1303da6afed1fc711100f457f1bf62b0b7e1 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -75,23 +75,42 @@ class FloatActivationsOpModel : public BaseActivationsOpModel { std::vector GetOutput() { return ExtractVector(output_); } }; -// TODO(ahentz): I don't quite understand the tradeoffs in the quantized -// implementation of sigmoid and software, but a tolerance of twice the output -// scale seems reasonable. We might want to change this if we have a better -// theoretical bound. +// Our fixed-point math function implementations have roughly 12 bits of +// accuracy, when specialized to 16-bit fixed-point arithmetic. +// That is purely an implementation compromise, it would have been possible +// to get closer to 16 bits of accuracy but that would be more expensive, +// and not needed for our purposes as ultimately the output is either +// immediately down-quantized to 8 bits, or will typically be at the output +// of the surrounding LSTM cell. +// So we can require roughly 2^-12 accuracy when the output is 16-bit, and +// we can more or less expect the full 2^-8 accuracy when the output is 8-bit. +// +// However, the representable output interval is often [-1, 1] (it has to be +// for tanh, and even for logistic, when we implement it in fixed-point, we +// typically have to do so on such a symmetric interval, e.g. ARM NEON only +// has signed fixed-point arithmetic (SQRDMULH)). As the width of [-1, 1] +// is 2, our representable values are often diluted by a factor of 2, whence +// the factor of 2 below. const float kQuantizedTolerance = 2 * (1. / 256); +const float kQuantizedToleranceInt16 = 2 * (1. / 4096); class QuantizedActivationsOpModel : public BaseActivationsOpModel { public: using BaseActivationsOpModel::BaseActivationsOpModel; + template void SetInput(std::initializer_list data) { - QuantizeAndPopulate(input_, data); + QuantizeAndPopulate(input_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + + std::vector GetOutput() { + return ExtractVector(output_); + } + template std::vector GetDequantizedOutput() { - return Dequantize(ExtractVector(output_), - GetScale(output_), GetZeroPoint(output_)); + return Dequantize(ExtractVector(output_), GetScale(output_), + GetZeroPoint(output_)); } }; @@ -152,24 +171,47 @@ TEST(FloatActivationsOpTest, Tanh) { } TEST(QuantizedActivationsOpTest, Tanh) { + const float kMin = -1; + const float kMax = 127.f / 128.f; QuantizedActivationsOpModel m( BuiltinOperator_TANH, - /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -8, 8}, - /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, -1, 1}); - m.SetInput({ + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ 0, -6, 2, 4, // -4, -2, 8, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { 0.0, -0.999987, 0.964027, 0.999329, // - -0.996078, -0.96402, 0.99999, 0.76159, // + -0.999329, -0.96402, 0.99999, 0.76159, // }, - 4 * (1. / 256)))); - EXPECT_THAT(m.GetOutput(), - ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 226})); + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 225})); +} + +TEST(QuantizedActivationsOpTest, TanhInt16) { + const float kMin = -1; + const float kMax = 32767.f / 32768.f; + QuantizedActivationsOpModel m( + BuiltinOperator_TANH, + /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ + 0, -6, 2, 4, // + -4, -2, 8, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.0, -0.999987, 0.964027, 0.999329, // + -0.999329, -0.96402, 0.99999, 0.76159, // + }, + kQuantizedToleranceInt16))); } TEST(FloatActivationsOpTest, Sigmoid) { @@ -190,22 +232,43 @@ TEST(QuantizedActivationsOpTest, Sigmoid) { QuantizedActivationsOpModel m( BuiltinOperator_LOGISTIC, /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { 0.5, 0.002473, 0.880797, 0.982014, // 0.952574, 0.119203, 0.999955, 0.731059, // }, kQuantizedTolerance))); - EXPECT_THAT(m.GetOutput(), + EXPECT_THAT(m.GetOutput(), ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); } +TEST(QuantizedActivationsOpTest, SigmoidInt16) { + const float kMin = -1; + const float kMax = 32767.f / 32768.f; + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax}, + /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedToleranceInt16))); +} + TEST(FloatActivationsOpTest, Softmax4D) { FloatActivationsOpModel m(0.1, /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}}); @@ -241,12 +304,12 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { QuantizedActivationsOpModel m( 0.1, /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // depth = 0 3, -2, 10, 1, // depth = 1 }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { .23463, .12877, .28658, .35003, // @@ -258,21 +321,22 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { QuantizedActivationsOpModel m2( 0.1, /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10}); - m2.SetInput({ + m2.SetInput({ 0, -6, // 2, 4, // 3, -2, // 10, 1, // }); m2.Invoke(); - EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - { - 0.645656, 0.354344, // - 0.450166, 0.549834, // - 0.622459, 0.377541, // - 0.710949, 0.28905, // - }, - kQuantizedTolerance))); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); } TEST(FloatActivationsOpTest, Softmax2D) { @@ -309,12 +373,12 @@ TEST(FloatActivationsOpTest, Softmax2D) { TEST(QuantizedActivationsOpTest, Softmax2D) { QuantizedActivationsOpModel m(0.1, /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}); - m.SetInput({ + m.SetInput({ 0, -6, 2, 4, // 3, -2, 10, 1, // }); m.Invoke(); - EXPECT_THAT(m.GetDequantizedOutput(), + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( { .23463, .12877, .28658, .35003, // @@ -325,21 +389,22 @@ TEST(QuantizedActivationsOpTest, Softmax2D) { // Same input, but a different shape. QuantizedActivationsOpModel m2(0.1, /*input=*/{TensorType_UINT8, {4, 2}, -10, 10}); - m2.SetInput({ + m2.SetInput({ 0, -6, // 2, 4, // 3, -2, // 10, 1, // }); m2.Invoke(); - EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( - { - 0.645656, 0.354344, // - 0.450166, 0.549834, // - 0.622459, 0.377541, // - 0.710949, 0.28905, // - }, - kQuantizedTolerance))); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); } // This contains the same test values as the Softmax test, but reference answer diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index 7ca1e35489cba3b5d2567bc04e532fedf8a527a7..ccb957ebc52e6ce9db3fbffb0c5beca9409edcc0 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -39,6 +39,23 @@ constexpr int kOutputTensor = 0; struct OpData { bool requires_broadcast; + + // These fields are used in both the general 8-bit -> 8bit quantized path, + // and the special 16-bit -> 16bit quantized path + int input1_shift; + int input2_shift; + int32 output_activation_min; + int32 output_activation_max; + + // These fields are used only in the general 8-bit -> 8bit quantized path + int32 input1_multiplier; + int32 input2_multiplier; + int32 output_multiplier; + int output_shift; + int left_shift; + int32 input1_offset; + int32 input2_offset; + int32 output_offset; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -52,6 +69,7 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); @@ -74,6 +92,80 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size = TfLiteIntArrayCopy(input1->dims); } + if (output->type == kTfLiteUInt8) { + // 8bit -> 8bit general quantized path, with general rescalings + data->input1_offset = -input1->params.zero_point; + data->input2_offset = -input2->params.zero_point; + data->output_offset = output->params.zero_point; + data->left_shift = 20; + const double twice_max_input_scale = + 2 * std::max(input1->params.scale, input2->params.scale); + const double real_input1_multiplier = + input1->params.scale / twice_max_input_scale; + const double real_input2_multiplier = + input2->params.scale / twice_max_input_scale; + const double real_output_multiplier = + twice_max_input_scale / + ((1 << data->left_shift) * output->params.scale); + + QuantizeMultiplierSmallerThanOneExp( + real_input1_multiplier, &data->input1_multiplier, &data->input1_shift); + data->input1_shift *= -1; + + QuantizeMultiplierSmallerThanOneExp( + real_input2_multiplier, &data->input2_multiplier, &data->input2_shift); + data->input2_shift *= -1; + + QuantizeMultiplierSmallerThanOneExp( + real_output_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; + + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + + } else if (output->type == kTfLiteInt16) { + // 16bit -> 16bit special quantized path, supporting only a rather + // narrow case of quantization parameters: zero_points must all be 0 + // ("symmetric quantization") and scales must be power-of-two (which + // we abbreviate as "POT" below). The intended use case for this path + // is in LSTM cells, where, due to the constraints of implementing + // some of the math in these LSTM cells in fixed-point arithmetic, + // we need to have such symmetric, power-of-two quantization + // (Fixed-point formats are inherently symmetric, power-of-two). + TF_LITE_ENSURE_EQ(context, input1->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, input2->params.zero_point, 0); + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + + int input1_scale_log2_rounded; + bool input1_scale_is_pot = + CheckedLog2(input1->params.scale, &input1_scale_log2_rounded); + TF_LITE_ENSURE(context, input1_scale_is_pot); + + int input2_scale_log2_rounded; + bool input2_scale_is_pot = + CheckedLog2(input2->params.scale, &input2_scale_log2_rounded); + TF_LITE_ENSURE(context, input2_scale_is_pot); + + int output_scale_log2_rounded; + bool output_scale_is_pot = + CheckedLog2(output->params.scale, &output_scale_log2_rounded); + TF_LITE_ENSURE(context, output_scale_is_pot); + + data->input1_shift = output_scale_log2_rounded - input1_scale_log2_rounded; + data->input2_shift = output_scale_log2_rounded - input2_scale_log2_rounded; + + // Shifting of one input is supported. The graph quantization should ensure + // that the other input matches the output. + TF_LITE_ENSURE(context, data->input1_shift == 0 || data->input2_shift == 0); + TF_LITE_ENSURE(context, data->input1_shift >= 0); + TF_LITE_ENSURE(context, data->input2_shift >= 0); + + CalculateActivationRangeQuantized(context, params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + return context->ResizeTensor(context, output, output_size); } @@ -107,56 +199,47 @@ void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, } template -void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - auto input1_offset = -input1->params.zero_point; - auto input2_offset = -input2->params.zero_point; - auto output_offset = output->params.zero_point; - const int left_shift = 20; - const double twice_max_input_scale = - 2 * std::max(input1->params.scale, input2->params.scale); - const double real_input1_multiplier = - input1->params.scale / twice_max_input_scale; - const double real_input2_multiplier = - input2->params.scale / twice_max_input_scale; - const double real_output_multiplier = - twice_max_input_scale / ((1 << left_shift) * output->params.scale); - - int32 input1_multiplier; - int input1_shift; - QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, - &input1_shift); - int32 input2_multiplier; - int input2_shift; - QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, - &input2_shift); - int32 output_multiplier; - int output_shift; - QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, - &output_shift); - - int32 output_activation_min, output_activation_max; - CalculateActivationRangeUint8(params->activation, output, - &output_activation_min, &output_activation_max); - -#define TF_LITE_ADD(type, opname) \ - type::opname(left_shift, GetTensorData(input1), \ - GetTensorDims(input1), input1_offset, input1_multiplier, \ - input1_shift, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, input2_multiplier, \ - input2_shift, output_offset, output_multiplier, output_shift, \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); - // The quantized version of Add doesn't support activations, so we - // always use BroadcastAdd. - if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops, BroadcastAdd); - } else { - TF_LITE_ADD(optimized_ops, BroadcastAdd); - } +TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, const OpData* data, + const TfLiteTensor* input1, + const TfLiteTensor* input2, + TfLiteTensor* output) { + if (output->type == kTfLiteUInt8) { +#define TF_LITE_ADD(type, opname) \ + type::opname( \ + data->left_shift, GetTensorData(input1), GetTensorDims(input1), \ + data->input1_offset, data->input1_multiplier, data->input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), \ + data->input2_offset, data->input2_multiplier, data->input2_shift, \ + data->output_offset, data->output_multiplier, data->output_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + // The quantized version of Add doesn't support activations, so we + // always use BroadcastAdd. + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops, BroadcastAdd); + } else { + TF_LITE_ADD(optimized_ops, BroadcastAdd); + } #undef TF_LITE_ADD + } else if (output->type == kTfLiteInt16) { +#define TF_LITE_ADD(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + data->input1_shift, GetTensorData(input2), \ + GetTensorDims(input2), data->input2_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + // The quantized version of Add doesn't support activations, so we + // always use BroadcastAdd. + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops, Add); + } else { + TF_LITE_ADD(optimized_ops, Add); + } +#undef TF_LITE_ADD + } + + return kTfLiteOk; } template @@ -171,12 +254,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { EvalAddFloat(context, node, params, data, input1, input2, output); - } else if (output->type == kTfLiteUInt8) { - EvalAddQuantized(context, node, params, data, input1, input2, - output); + } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + TF_LITE_ENSURE_OK(context, + EvalAddQuantized(context, node, params, data, + input1, input2, output)); } else { context->ReportError(context, - "Inputs and outputs not all float|uint8 types."); + "Inputs and outputs not all float|uint8|int16 types."); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc index 956d05bed5162f6ce59705d59aad77ff056dda77..456a754e7ee191fa74280da7af8fa844b2ef1923 100644 --- a/tensorflow/contrib/lite/kernels/add_test.cc +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -60,15 +60,26 @@ class QuantizedAddOpModel : public BaseAddOpModel { return Dequantize(ExtractVector(output_), GetScale(output_), GetZeroPoint(output_)); } + + std::vector GetDequantizedOutputInt16() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } }; // for quantized Add, the error shouldn't exceed 2*step -float GetTolerance(int min, int max) { +float GetTolerance(float min, float max) { float kQuantizedStep = (max - min) / 255.0; float kQuantizedTolerance = 2.0 * kQuantizedStep; return kQuantizedTolerance; } +float GetToleranceInt16(float min, float max) { + float kQuantizedStep = (max - min) / 32767.f; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + TEST(FloatAddOpModel, NoActivation) { FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, @@ -144,6 +155,31 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { } } +TEST(QuantizedAddOpModel, QuantizedTestsNoActivationInt16) { + const float kMin = -1.f; + const float kMax = 32767.f / 32768.f; + float kQuantizedTolerance = GetToleranceInt16(kMin, kMax); + std::vector> inputs1 = { + {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = { + {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = { + {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT( + m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index 3b81062cd42f04582b33ea919ef2742d3d869c22..f678f48fa5bbbcece6c5b87030d951783378d78f 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -23,6 +23,7 @@ namespace tflite { namespace ops { namespace builtin { namespace comparisons { +namespace { constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; @@ -67,6 +68,57 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { GetTensorData(input2), GetTensorDims(input2), \ GetTensorData(output), GetTensorDims(output)); +TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Equal, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type %d, requires float|int", + input1->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +// TODO(renjieliu): Refactor the logic to avoid duplications. +TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, NotEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type %d, requires float|int", + input1->type); + return kTfLiteError; + } + return kTfLiteOk; +} + TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); @@ -167,8 +219,22 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +} // namespace } // namespace comparisons +TfLiteRegistration* Register_EQUAL() { + static TfLiteRegistration r = { + nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval}; + return &r; +} + +TfLiteRegistration* Register_NOT_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::NotEqualEval}; + return &r; +} + TfLiteRegistration* Register_GREATER() { static TfLiteRegistration r = {nullptr, nullptr, comparisons::ComparisonPrepare, diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index 835d238d36d1757a27119ae24b3c07232e9d3dc0..bb02e1c812fdc40bf515f1f978e9e39b5a16a4ea 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -21,18 +21,17 @@ limitations under the License. namespace tflite { namespace { -using ::testing::ElementsAreArray; +using ::testing::ElementsAre; -class GreaterOpModel : public SingleOpModel { +class ComparisonOpModel : public SingleOpModel { public: - GreaterOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { + ComparisonOpModel(std::initializer_list input1_shape, + std::initializer_list input2_shape, + TensorType input_type, BuiltinOperator op) { input1_ = AddInput(input_type); input2_ = AddInput(input_type); output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions, - CreateGreaterOptions(builder_).Union()); + ConfigureBuiltinOp(op); BuildInterpreter({input1_shape, input2_shape}); } @@ -46,245 +45,313 @@ class GreaterOpModel : public SingleOpModel { int input1_; int input2_; int output_; + + void ConfigureBuiltinOp(BuiltinOperator op) { + switch (op) { + case BuiltinOperator_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_EqualOptions, + CreateEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_NOT_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_NotEqualOptions, + CreateNotEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_GREATER: { + SetBuiltinOp(op, BuiltinOptions_GreaterOptions, + CreateGreaterOptions(builder_).Union()); + break; + } + case BuiltinOperator_GREATER_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_GreaterEqualOptions, + CreateGreaterEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_LESS: { + SetBuiltinOp(op, BuiltinOptions_LessOptions, + CreateLessOptions(builder_).Union()); + break; + } + case BuiltinOperator_LESS_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_LessEqualOptions, + CreateLessEqualOptions(builder_).Union()); + break; + } + default: { FAIL() << "We shouldn't get here."; } + } + } }; -TEST(ComparisonsTest, GreaterFloat) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); +TEST(ComparisonsTest, EqualFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterInt) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); +TEST(ComparisonsTest, EqualInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterBroadcast) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); +TEST(ComparisonsTest, EqualBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterBroadcastTwoD) { - GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); +TEST(ComparisonsTest, EqualBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, - false, true, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, false, false, + false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class GreaterEqualOpModel : public SingleOpModel { - public: - GreaterEqualOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_GREATER_EQUAL, - BuiltinOptions_GreaterEqualOptions, - CreateGreaterEqualOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } +TEST(ComparisonsTest, NotEqualFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); - int input1() { return input1_; } - int input2() { return input2_; } + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } +TEST(ComparisonsTest, NotEqualInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {1, 2, 7, 5}); + model.Invoke(); - private: - int input1_; - int input2_; - int output_; -}; + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, NotEqualBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, NotEqualBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, true, true, true, true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); +} + +TEST(ComparisonsTest, GreaterFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); +} TEST(ComparisonsTest, GreaterEqualFloat) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualInt) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualBroadcast) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) { - GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, - false, true, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, true, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class LessOpModel : public SingleOpModel { - public: - LessOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions, - CreateLessOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } - - int input1() { return input1_; } - int input2() { return input2_; } - - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } - - private: - int input1_; - int input2_; - int output_; -}; TEST(ComparisonsTest, LessFloat) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessInt) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 6, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessBroadcast) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessBroadcastTwoD) { - LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, - true, false, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class LessEqualOpModel : public SingleOpModel { - public: - LessEqualOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions, - CreateLessEqualOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } - - int input1() { return input1_; } - int input2() { return input2_; } - - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } - - private: - int input1_; - int input2_; - int output_; -}; - TEST(ComparisonsTest, LessEqualFloat) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualInt) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualBroadcast) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualBroadcastTwoD) { - LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, - true, false, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index ee42e5cdc838fac4bf9a3de15b7e95e001588907..14b399ef96eab1d5066a22a7eb95ab061e8ba2bc 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -134,7 +134,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). data->need_im2col = (params->stride_width != 1 || params->stride_height != 1 || - filter_width != 1 || filter_height != 1); + params->dilation_width_factor != 1 || + params->dilation_height_factor != 1 || filter_width != 1 || + filter_height != 1); // If we're using the optimized multithreaded EigenTensor implementation of // convolution, it expects the filter weights to be transposed compared to // the normal TF Lite buffer format. Typical TF Lite weights are @@ -255,8 +257,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); TF_LITE_ENSURE(context, real_multiplier < 1.0); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc new file mode 100644 index 0000000000000000000000000000000000000000..e4ee5885e9882d279dc0923e758d859ab5d13ff2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc @@ -0,0 +1,589 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace detection_postprocess { + +// Input tensors +constexpr int kInputTensorBoxEncodings = 0; +constexpr int kInputTensorClassPredictions = 1; +constexpr int kInputTensorAnchors = 2; + +// Output tensors +constexpr int kOutputTensorDetectionBoxes = 0; +constexpr int kOutputTensorDetectionClasses = 1; +constexpr int kOutputTensorDetectionScores = 2; +constexpr int kOutputTensorNumDetections = 3; + +constexpr size_t kNumCoordBox = 4; +constexpr size_t kBatchSize = 1; + +// Object Detection model produces axis-aligned boxes in two formats: +// BoxCorner represents the upper right (xmin, ymin) and +// lower left corner (xmax, ymax). +// CenterSize represents the center (xcenter, ycenter), height and width. +// BoxCornerEncoding and CenterSizeEncoding are related as follows: +// ycenter = y / y_scale * anchor.h + anchor.y; +// xcenter = x / x_scale * anchor.w + anchor.x; +// half_h = 0.5*exp(h/ h_scale)) * anchor.h; +// half_w = 0.5*exp(w / w_scale)) * anchor.w; +// ymin = ycenter - half_h +// ymax = ycenter + half_h +// xmin = xcenter - half_w +// xmax = xcenter + half_w +struct BoxCornerEncoding { + float ymin; + float xmin; + float ymax; + float xmax; +}; + +struct CenterSizeEncoding { + float y; + float x; + float h; + float w; +}; +// We make sure that the memory allocations are contiguous with static assert. +static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox, + "Size of BoxCornerEncoding is 4 float values"); +static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox, + "Size of CenterSizeEncoding is 4 float values"); + +struct OpData { + int max_detections; + int max_classes_per_detection; + float non_max_suppression_score_threshold; + float intersection_over_union_threshold; + int num_classes; + CenterSizeEncoding scale_values; + // Indices of Temporary tensors + int decoded_boxes_index; + int scores_index; + int active_candidate_index; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + op_data->max_detections = m["max_detections"].AsInt32(); + op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32(); + op_data->non_max_suppression_score_threshold = + m["nms_score_threshold"].AsFloat(); + op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat(); + op_data->num_classes = m["num_classes"].AsInt32(); + op_data->scale_values.y = m["y_scale"].AsFloat(); + op_data->scale_values.x = m["x_scale"].AsFloat(); + op_data->scale_values.h = m["h_scale"].AsFloat(); + op_data->scale_values.w = m["w_scale"].AsFloat(); + context->AddTensors(context, 1, &op_data->decoded_boxes_index); + context->AddTensors(context, 1, &op_data->scores_index); + context->AddTensors(context, 1, &op_data->active_candidate_index); + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +// TODO(chowdhery): Add to kernel_util.h +TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor, + std::initializer_list values) { + TfLiteIntArray* size = TfLiteIntArrayCreate(values.size()); + int index = 0; + for (int v : values) { + size->data[index] = v; + ++index; + } + return context->ResizeTensor(context, tensor, size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* op_data = reinterpret_cast(node->user_data); + // Inputs: box_encodings, scores, anchors + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* input_class_predictions = + GetInput(context, node, kInputTensorClassPredictions); + const TfLiteTensor* input_anchors = + GetInput(context, node, kInputTensorAnchors); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2); + // number of detected boxes + const int num_detected_boxes = + op_data->max_detections * op_data->max_classes_per_detection; + + // Outputs: detection_boxes, detection_scores, detection_classes, + // num_detections + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4); + // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4) + TfLiteTensor* detection_boxes = + GetOutput(context, node, kOutputTensorDetectionBoxes); + detection_boxes->type = kTfLiteFloat32; + SetTensorSizes(context, detection_boxes, + {kBatchSize, num_detected_boxes, kNumCoordBox}); + + // Output Tensor detection_classes: size is set to (1, num_detected_boxes) + TfLiteTensor* detection_classes = + GetOutput(context, node, kOutputTensorDetectionClasses); + detection_classes->type = kTfLiteFloat32; + SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes}); + + // Output Tensor detection_scores: size is set to (1, num_detected_boxes) + TfLiteTensor* detection_scores = + GetOutput(context, node, kOutputTensorDetectionScores); + detection_scores->type = kTfLiteFloat32; + SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes}); + + // Output Tensor num_detections: size is set to 1 + TfLiteTensor* num_detections = + GetOutput(context, node, kOutputTensorNumDetections); + num_detections->type = kTfLiteFloat32; + // TODO (chowdhery): Make it a scalar when available + SetTensorSizes(context, num_detections, {1}); + + // Temporary tensors + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(3); + node->temporaries->data[0] = op_data->decoded_boxes_index; + node->temporaries->data[1] = op_data->scores_index; + node->temporaries->data[2] = op_data->active_candidate_index; + + // decoded_boxes + TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index]; + decoded_boxes->type = kTfLiteFloat32; + decoded_boxes->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, decoded_boxes, + {input_box_encodings->dims->data[1], kNumCoordBox}); + + // scores + TfLiteTensor* scores = &context->tensors[op_data->scores_index]; + scores->type = kTfLiteFloat32; + scores->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, scores, + {input_class_predictions->dims->data[1], + input_class_predictions->dims->data[2]}); + + // active_candidate + TfLiteTensor* active_candidate = + &context->tensors[op_data->active_candidate_index]; + active_candidate->type = kTfLiteUInt8; + active_candidate->allocation_type = kTfLiteArenaRw; + SetTensorSizes(context, active_candidate, + {input_box_encodings->dims->data[1]}); + + return kTfLiteOk; +} + +class Dequantizer { + public: + Dequantizer(int zero_point, float scale) + : zero_point_(zero_point), scale_(scale) {} + float operator()(uint8 x) { + return (static_cast(x) - zero_point_) * scale_; + } + + private: + int zero_point_; + float scale_; +}; + +void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx, + float quant_zero_point, float quant_scale, + CenterSizeEncoding* box_centersize) { + const uint8* boxes = + GetTensorData(input_box_encodings) + kNumCoordBox * idx; + Dequantizer dequantize(quant_zero_point, quant_scale); + box_centersize->y = dequantize(boxes[0]); + box_centersize->x = dequantize(boxes[1]); + box_centersize->h = dequantize(boxes[2]); + box_centersize->w = dequantize(boxes[3]); +} + +template +T ReInterpretTensor(const TfLiteTensor* tensor) { + // TODO (chowdhery): check float + const float* tensor_base = tensor->data.f; + return reinterpret_cast(tensor_base); +} + +template +T ReInterpretTensor(TfLiteTensor* tensor) { + // TODO (chowdhery): check float + float* tensor_base = tensor->data.f; + return reinterpret_cast(tensor_base); +} + +TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node, + OpData* op_data) { + // Parse input tensor boxencodings + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize); + const int num_boxes = input_box_encodings->dims->data[1]; + TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[2], kNumCoordBox); + + // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors + CenterSizeEncoding box_centersize; + CenterSizeEncoding scale_values = op_data->scale_values; + const float quant_zero_point = + static_cast(input_box_encodings->params.zero_point); + const float quant_scale = + static_cast(input_box_encodings->params.scale); + for (int idx = 0; idx < num_boxes; ++idx) { + switch (input_box_encodings->type) { + // Quantized + case kTfLiteUInt8: + DequantizeBoxEncodings(input_box_encodings, idx, quant_zero_point, + quant_scale, &box_centersize); + break; + // Float + case kTfLiteFloat32: + box_centersize = ReInterpretTensor( + input_box_encodings)[idx]; + break; + default: + // Unsupported type. + return kTfLiteError; + } + + const TfLiteTensor* input_anchors = + GetInput(context, node, kInputTensorAnchors); + + const auto& anchor = + ReInterpretTensor(input_anchors)[idx]; + + float ycenter = box_centersize.y / scale_values.y * anchor.h + anchor.y; + float xcenter = box_centersize.x / scale_values.x * anchor.w + anchor.x; + float half_h = + 0.5f * static_cast(std::exp(box_centersize.h / scale_values.h)) * + anchor.h; + float half_w = + 0.5f * static_cast(std::exp(box_centersize.w / scale_values.w)) * + anchor.w; + TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + auto& box = ReInterpretTensor(decoded_boxes)[idx]; + box.ymin = ycenter - half_h; + box.xmin = xcenter - half_w; + box.ymax = ycenter + half_h; + box.xmax = xcenter + half_w; + } + return kTfLiteOk; +} + +void DecreasingPartialArgSort(const float* values, int num_values, + int num_to_sort, int* indices) { + std::iota(indices, indices + num_values, 0); + std::partial_sort( + indices, indices + num_to_sort, indices + num_values, + [&values](const int i, const int j) { return values[i] > values[j]; }); +} + +void SelectDetectionsAboveScoreThreshold(const std::vector& values, + const float threshold, + std::vector* keep_values, + std::vector* keep_indices) { + for (int i = 0; i < values.size(); i++) { + if (values[i] >= threshold) { + keep_values->emplace_back(values[i]); + keep_indices->emplace_back(i); + } + } +} + +bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) { + for (int i = 0; i < num_boxes; ++i) { + // ymax>=ymin, xmax>=xmin + auto& box = ReInterpretTensor(decoded_boxes)[i]; + if (box.ymin >= box.ymax || box.xmin >= box.xmax) { + return false; + } + } + return true; +} + +float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes, + const int i, const int j) { + auto& box_i = ReInterpretTensor(decoded_boxes)[i]; + auto& box_j = ReInterpretTensor(decoded_boxes)[j]; + const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin); + const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin); + if (area_i <= 0 || area_j <= 0) return 0.0; + const float intersection_ymin = std::max(box_i.ymin, box_j.ymin); + const float intersection_xmin = std::max(box_i.xmin, box_j.xmin); + const float intersection_ymax = std::min(box_i.ymax, box_j.ymax); + const float intersection_xmax = std::min(box_i.xmax, box_j.xmax); + const float intersection_area = + std::max(intersection_ymax - intersection_ymin, 0.0) * + std::max(intersection_xmax - intersection_xmin, 0.0); + return intersection_area / (area_i + area_j - intersection_area); +} + +// NonMaxSuppressionSingleClass() is O(n^2) pairwise comparison between boxes +// It assumes all boxes are good in beginning and sorts based on the scores. +// If lower-scoring box has too much overlap with a higher-scoring box, +// we get rid of the lower-scoring box. +TfLiteStatus NonMaxSuppressionSingleClassHelper( + TfLiteContext* context, TfLiteNode* node, OpData* op_data, + const std::vector& scores, std::vector* selected) { + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + const int num_boxes = input_box_encodings->dims->data[1]; + const int max_detections = op_data->max_detections; + const float non_max_suppression_score_threshold = + op_data->non_max_suppression_score_threshold; + const float intersection_over_union_threshold = + op_data->intersection_over_union_threshold; + // Maximum detections should be positive. + TF_LITE_ENSURE(context, (max_detections >= 0)); + // intersection_over_union_threshold should be positive + // and should be less than 1. + TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) && + (intersection_over_union_threshold <= 1.0f)); + // Validate boxes + TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes)); + + // threshold scores + std::vector keep_indices; + // TODO (chowdhery): Remove the dynamic allocation and replace it + // with temporaries, esp for std::vector + std::vector keep_scores; + SelectDetectionsAboveScoreThreshold( + scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices); + + int num_scores_kept = keep_scores.size(); + std::vector sorted_indices; + sorted_indices.resize(num_scores_kept); + DecreasingPartialArgSort(keep_scores.data(), num_scores_kept, num_scores_kept, + sorted_indices.data()); + + const int num_boxes_kept = keep_scores.size(); + const int output_size = std::min(num_boxes_kept, max_detections); + selected->clear(); + TfLiteTensor* active_candidate = + &context->tensors[op_data->active_candidate_index]; + TF_LITE_ENSURE(context, (active_candidate->dims->data[0]) == num_boxes); + int num_active_candidate = num_boxes; + uint8_t* active_box_candidate = (active_candidate->data.uint8); + for (int row = 0; row < num_boxes; row++) { + active_box_candidate[row] = 1; + } + + for (int i = 0; i < num_boxes; ++i) { + if (num_active_candidate == 0 || selected->size() >= output_size) break; + if (active_box_candidate[i] == 1) { + selected->push_back(keep_indices[sorted_indices[i]]); + active_box_candidate[i] = 0; + num_active_candidate--; + } else { + continue; + } + for (int j = i + 1; j < num_boxes; ++j) { + if (active_box_candidate[j] == 1) { + float intersection_over_union = ComputeIntersectionOverUnion( + decoded_boxes, keep_indices[sorted_indices[i]], + keep_indices[sorted_indices[j]]); + + if (intersection_over_union > intersection_over_union_threshold) { + active_box_candidate[j] = 0; + num_active_candidate--; + } + } + } + } + return kTfLiteOk; +} + +// This function implements a fast version of Non Maximal Suppression for +// multiple classes where +// 1) we keep the top-k scores for each anchor and +// 2) during NMS, each anchor only uses the highest class score for sorting. +// 3) Compared to standard NMS, the worst runtime of this version is O(N^2) +// instead of O(KN^2) where N is the number of anchors and K the number of +// classes. +TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, + TfLiteNode* node, + OpData* op_data, + const float* scores) { + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* decoded_boxes = + &context->tensors[op_data->decoded_boxes_index]; + + TfLiteTensor* detection_boxes = + GetOutput(context, node, kOutputTensorDetectionBoxes); + TfLiteTensor* detection_classes = + GetOutput(context, node, kOutputTensorDetectionClasses); + TfLiteTensor* detection_scores = + GetOutput(context, node, kOutputTensorDetectionScores); + TfLiteTensor* num_detections = + GetOutput(context, node, kOutputTensorNumDetections); + + const int num_boxes = input_box_encodings->dims->data[1]; + const int num_classes = op_data->num_classes; + const int max_categories_per_anchor = op_data->max_classes_per_detection; + // The row index offset is 1 if background class is included and 0 otherwise. + const int label_offset = 1; + TF_LITE_ENSURE(context, (label_offset != -1)); + TF_LITE_ENSURE(context, (max_categories_per_anchor > 0)); + const int num_classes_with_background = num_classes + label_offset; + const int num_categories_per_anchor = + std::min(max_categories_per_anchor, num_classes); + std::vector max_scores; + max_scores.resize(num_boxes); + std::vector sorted_class_indices; + sorted_class_indices.resize(num_boxes * num_classes); + for (int row = 0; row < num_boxes; row++) { + const float* box_scores = + scores + row * num_classes_with_background + label_offset; + int* class_indices = sorted_class_indices.data() + row * num_classes; + DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor, + class_indices); + max_scores[row] = box_scores[class_indices[0]]; + } + // Perform non-maximal suppression on max scores + std::vector selected; + NonMaxSuppressionSingleClassHelper(context, node, op_data, max_scores, + &selected); + // Allocate output tensors + int output_box_index = 0; + for (const auto& selected_index : selected) { + const float* box_scores = + scores + selected_index * num_classes_with_background + label_offset; + const int* class_indices = + sorted_class_indices.data() + selected_index * num_classes; + + for (int col = 0; col < num_categories_per_anchor; ++col) { + int box_offset = num_categories_per_anchor * output_box_index + col; + // detection_boxes + ReInterpretTensor(detection_boxes)[box_offset] = + ReInterpretTensor( + decoded_boxes)[selected_index]; + // detection_classes + detection_classes->data.f[box_offset] = class_indices[col]; + // detection_scores + detection_scores->data.f[box_offset] = box_scores[class_indices[col]]; + output_box_index++; + } + } + num_detections->data.f[0] = output_box_index; + return kTfLiteOk; +} + +void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions, + const int num_boxes, + const int num_classes_with_background, + const TfLiteTensor* scores) { + float quant_zero_point = + static_cast(input_class_predictions->params.zero_point); + float quant_scale = static_cast(input_class_predictions->params.scale); + Dequantizer dequantize(quant_zero_point, quant_scale); + const uint8* scores_quant = GetTensorData(input_class_predictions); + for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) { + scores->data.f[idx] = dequantize(scores_quant[idx]); + } +} + +TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context, + TfLiteNode* node, OpData* op_data) { + // Get the input tensors + const TfLiteTensor* input_box_encodings = + GetInput(context, node, kInputTensorBoxEncodings); + const TfLiteTensor* input_class_predictions = + GetInput(context, node, kInputTensorClassPredictions); + const int num_boxes = input_box_encodings->dims->data[1]; + const int num_classes = op_data->num_classes; + TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0], + kBatchSize); + TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes); + const int num_classes_with_background = + input_class_predictions->dims->data[2]; + + TF_LITE_ENSURE(context, (num_classes_with_background == num_classes + 1)); + + const TfLiteTensor* scores; + switch (input_class_predictions->type) { + case kTfLiteUInt8: { + TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index]; + DequantizeClassPredictions(input_class_predictions, num_boxes, + num_classes_with_background, temporary_scores); + scores = temporary_scores; + } break; + case kTfLiteFloat32: + scores = input_class_predictions; + break; + default: + // Unsupported type. + return kTfLiteError; + } + NonMaxSuppressionMultiClassFastHelper(context, node, op_data, + GetTensorData(scores)); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + // TODO(chowdhery): Generalize for any batch size + TF_LITE_ENSURE(context, (kBatchSize == 1)); + auto* op_data = reinterpret_cast(node->user_data); + // These two functions correspond to two blocks in the Object Detection model. + // In future, we would like to break the custom op in two blocks, which is + // currently not feasible because we would like to input quantized inputs + // and do all calculations in float. Mixed quantized/float calculations are + // currently not supported in TFLite. + + // This fills in temporary decoded_boxes + // by transforming input_box_encodings and input_anchors from + // CenterSizeEncodings to BoxCornerEncoding + DecodeCenterSizeBoxes(context, node, op_data); + // This fills in the output tensors + // by choosing effective set of decoded boxes + // based on Non Maximal Suppression, i.e. selecting + // highest scoring non-overlapping boxes. + NonMaxSuppressionMultiClass(context, node, op_data); + + return kTfLiteOk; +} +} // namespace detection_postprocess + +TfLiteRegistration* Register_DETECTION_POSTPROCESS() { + static TfLiteRegistration r = {detection_postprocess::Init, + detection_postprocess::Free, + detection_postprocess::Prepare, + detection_postprocess::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e801c5ace3a9571003e804164c7a56de84d0457f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc @@ -0,0 +1,233 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class BaseDetectionPostprocessOpModel : public SingleOpModel { + public: + BaseDetectionPostprocessOpModel(const TensorData& input1, + const TensorData& input2, + const TensorData& input3, + const TensorData& output1, + const TensorData& output2, + const TensorData& output3, + const TensorData& output4) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + input3_ = AddInput(input3); + output1_ = AddOutput(output1); + output2_ = AddOutput(output2); + output3_ = AddOutput(output3); + output4_ = AddOutput(output4); + + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("max_detections", 3); + fbb.Int("max_classes_per_detection", 1); + fbb.Float("nms_score_threshold", 0.0); + fbb.Float("nms_iou_threshold", 0.5); + fbb.Int("num_classes", 2); + fbb.Float("y_scale", 10.0); + fbb.Float("x_scale", 10.0); + fbb.Float("h_scale", 5.0); + fbb.Float("w_scale", 5.0); + }); + fbb.Finish(); + SetCustomOp("TFLite_Detection_PostProcess", fbb.GetBuffer(), + Register_DETECTION_POSTPROCESS); + BuildInterpreter({GetShape(input1_), GetShape(input2_), GetShape(input3_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + int input3() { return input3_; } + + template + void SetInput1(std::initializer_list data) { + PopulateTensor(input1_, data); + } + + template + void SetInput2(std::initializer_list data) { + PopulateTensor(input2_, data); + } + + template + void SetInput3(std::initializer_list data) { + PopulateTensor(input3_, data); + } + + template + std::vector GetOutput1() { + return ExtractVector(output1_); + } + + template + std::vector GetOutput2() { + return ExtractVector(output2_); + } + + template + std::vector GetOutput3() { + return ExtractVector(output3_); + } + + template + std::vector GetOutput4() { + return ExtractVector(output4_); + } + + std::vector GetOutputShape1() { return GetTensorShape(output1_); } + std::vector GetOutputShape2() { return GetTensorShape(output2_); } + std::vector GetOutputShape3() { return GetTensorShape(output3_); } + std::vector GetOutputShape4() { return GetTensorShape(output4_); } + + protected: + int input1_; + int input2_; + int input3_; + int output1_; + int output2_; + int output3_; + int output4_; +}; + +TEST(DetectionPostprocessOpTest, FloatTest) { + BaseDetectionPostprocessOpModel m( + {TensorType_FLOAT32, {1, 6, 4}}, {TensorType_FLOAT32, {1, 6, 3}}, + {TensorType_FLOAT32, {6, 4}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}); + + // six boxes in center-size encoding + m.SetInput1({0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, + 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}); + // class scores - two classes with background + m.SetInput2({0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., + .5, .4, 0., .3, .2}); + // six anchors in center-size encoding + m.SetInput3({0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, + 0.5, 0.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, + 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}); + // Same boxes in box-corner encoding: + // { 0.0, 0.0, 1.0, 1.0, + // 0.0, 0.1, 1.0, 1.1, + // 0.0, -0.1, 1.0, 0.9, + // 0.0, 10.0, 1.0, 11.0, + // 0.0, 10.1, 1.0, 11.1, + // 0.0, 100.0, 1.0, 101.0} + m.Invoke(); + // detection_boxes + // in center-size + std::vector output_shape1 = m.GetOutputShape1(); + EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4)); + EXPECT_THAT( + m.GetOutput1(), + ElementsAreArray(ArrayFloatNear( + {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0}, + 1e-1))); + // detection_classes + std::vector output_shape2 = m.GetOutputShape2(); + EXPECT_THAT(output_shape2, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1))); + // detection_scores + std::vector output_shape3 = m.GetOutputShape3(); + EXPECT_THAT(output_shape3, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1))); + // num_detections + std::vector output_shape4 = m.GetOutputShape4(); + EXPECT_THAT(output_shape4, ElementsAre(1)); + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({3.0}, 1e-1))); +} + +TEST(DetectionPostprocessOpTest, QuantizedTest) { + BaseDetectionPostprocessOpModel m( + {TensorType_UINT8, {1, 6, 4}, -1.0, 1.0}, + {TensorType_UINT8, {1, 6, 3}, 0.0, 1.0}, {TensorType_FLOAT32, {6, 4}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, + {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}); + // six boxes in center-size encoding + std::vector> inputs1 = { + {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}; + m.QuantizeAndPopulate(m.input1(), inputs1[0]); + // class scores - two classes with background + std::vector> inputs2 = { + {0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., .5, .4, 0., .3, + .2}}; + m.QuantizeAndPopulate(m.input2(), inputs2[0]); + // six anchors in center-size encoding + m.SetInput3({0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, + 0.5, 0.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, + 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}); + m.Invoke(); + // detection_boxes + // in center-size + std::vector output_shape1 = m.GetOutputShape1(); + EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4)); + EXPECT_THAT( + m.GetOutput1(), + ElementsAreArray(ArrayFloatNear( + {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0}, + 1e-1))); + // detection_classes + std::vector output_shape2 = m.GetOutputShape2(); + EXPECT_THAT(output_shape2, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput2(), + ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1))); + // detection_scores + std::vector output_shape3 = m.GetOutputShape3(); + EXPECT_THAT(output_shape3, ElementsAre(1, 3)); + EXPECT_THAT(m.GetOutput3(), + ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1))); + // num_detections + std::vector output_shape4 = m.GetOutputShape4(); + EXPECT_THAT(output_shape4, ElementsAre(1)); + EXPECT_THAT(m.GetOutput4(), + ElementsAreArray(ArrayFloatNear({3.0}, 1e-1))); +} +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index 0bd504695074011efd946f4c4d1f8d4854e82730..59bab3c4ecd20bf938919ca606a5933f3112f233 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -23,7 +23,7 @@ namespace ops { namespace builtin { namespace elementwise { -TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); @@ -35,7 +35,8 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayCopy(input->dims)); } -TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { +inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, + float float_func(float)) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { @@ -44,7 +45,7 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { const float* in = GetTensorData(input); const float* in_end = in + elements; float* out = output->data.f; - for (; in < in_end; in++, out++) *out = std::sin(*in); + for (; in < in_end; in++, out++) *out = float_func(*in); return kTfLiteOk; } default: { @@ -55,14 +56,48 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { } } +TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::sin); +} + +TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::log); +} + +TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::sqrt); +} + +TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, [](float f) { return 1.f / std::sqrt(f); }); +} + } // namespace elementwise TfLiteRegistration* Register_SIN() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare, + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, elementwise::SinEval}; return &r; } +TfLiteRegistration* Register_LOG() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::LogEval}; + return &r; +} + +TfLiteRegistration* Register_SQRT() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::SqrtEval}; + return &r; +} + +TfLiteRegistration* Register_RSQRT() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::RsqrtEval}; + return &r; +} + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index 412ffb04b90fbc24d232d25d2a86ce639752c3e8..ce4c602ee5c788d67701af3ecd3e023f2b25aae7 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -24,12 +24,13 @@ namespace { using ::testing::ElementsAreArray; -class SinOpModel : public SingleOpModel { +class ElementWiseOpModel : public SingleOpModel { public: - SinOpModel(std::initializer_list input_shape) { + ElementWiseOpModel(BuiltinOperator op, + std::initializer_list input_shape) { input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_SIN, BuiltinOptions_NONE, 0); + SetBuiltinOp(op, BuiltinOptions_NONE, 0); BuildInterpreter({input_shape}); } @@ -42,7 +43,7 @@ class SinOpModel : public SingleOpModel { }; TEST(ElementWise, Sin) { - SinOpModel m({1, 1, 4, 1}); + ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {0, 3.1415926, -3.1415926, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -50,6 +51,33 @@ TEST(ElementWise, Sin) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, Log) { + ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 3.1415926, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1.14473, 0, 0}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(ElementWise, Sqrt) { + ElementWiseOpModel m(BuiltinOperator_SQRT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {0, 1, 2, 4}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1, 1.41421, 2}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + +TEST(ElementWise, Rsqrt) { + ElementWiseOpModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 2, 4, 9}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({1, 0.7071, 0.5, 0.33333}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index 7539c0b30ded921df957217bebdc7b20ea4b40b4..9410bead5e7a68363d034c22fb2c0eff9f060ef1 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -24,7 +24,8 @@ limitations under the License. // Output: // Output.dim[0] == Tensor[0].dim[0], num of lookups // Output.dim[1] == Tensor[1].dim[1], num of items per row -// Each item in output is a raw bytes copy of corresponding item in input. +// Each item in output is a raw bytes copy of the corresponding item in input, +// or a dequantized value in the case of a uint8 input. // When indices are out of bound, the ops will not succeed. // @@ -69,11 +70,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, outputSize); } -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* output = GetOutput(context, node, 0); - const TfLiteTensor* lookup = GetInput(context, node, 0); - const TfLiteTensor* value = GetInput(context, node, 1); - +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lookup, const TfLiteTensor* value, + TfLiteTensor* output) { const int row_size = SizeOfDimension(value, 0); const int row_bytes = value->bytes / row_size; @@ -91,6 +90,52 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lookup, const TfLiteTensor* value, + TfLiteTensor* output) { + const int row_size = SizeOfDimension(value, 0); + const double scaling_factor = 1.0 / value->params.scale; + + // col_size after we flatten tensor into 2D. + int col_size = 1; + for (int i = 1; i < NumDimensions(value); i++) { + col_size *= SizeOfDimension(value, i); + } + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = lookup->data.i32[i]; + if (idx >= row_size || idx < 0) { + context->ReportError(context, "Embedding Lookup: index out of bounds."); + return kTfLiteError; + } else { + // Dequantize embedding values. + // TODO(alanchiao): refactor scalar multiply into separate function + // for ease of adding a neon equivalent if ever necessary. + for (int j = 0; j < col_size; j++) { + output->data.f[j + i * col_size] = + value->data.uint8[j + idx * col_size] * scaling_factor; + } + } + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* lookup = GetInput(context, node, 0); + const TfLiteTensor* value = GetInput(context, node, 1); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (value->type) { + case kTfLiteFloat32: + return EvalFloat(context, node, lookup, value, output); + case kTfLiteUInt8: + return EvalHybrid(context, node, lookup, value, output); + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } +} + } // namespace embedding_lookup TfLiteRegistration* Register_EMBEDDING_LOOKUP() { diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 9b501878f196216a61568bfa36e6615f4dd07478..04657fd86323ef1c58d069c06097c7665f55cc87 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -7,13 +7,14 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License +for the specific language governing permissions and limitations under the +License. ==============================================================================*/ // Unit test for TFLite Lookup op. +#include #include #include @@ -29,12 +30,13 @@ namespace { using ::testing::ElementsAreArray; -class EmbeddingLookupOpModel : public SingleOpModel { +class BaseEmbeddingLookupOpModel : public SingleOpModel { public: - EmbeddingLookupOpModel(std::initializer_list index_shape, - std::initializer_list weight_shape) { + BaseEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape, + TensorType weight_type = TensorType_FLOAT32) { input_ = AddInput(TensorType_INT32); - weight_ = AddInput(TensorType_FLOAT32); + weight_ = AddInput(weight_type); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); BuildInterpreter({index_shape, weight_shape}); @@ -44,6 +46,18 @@ class EmbeddingLookupOpModel : public SingleOpModel { PopulateTensor(input_, data); } + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weight_; + int output_; +}; + +class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel; + void Set3DWeightMatrix(const std::function& function) { TfLiteTensor* tensor = interpreter_->tensor(weight_); int rows = tensor->dims->data[0]; @@ -57,20 +71,25 @@ class EmbeddingLookupOpModel : public SingleOpModel { } } } +}; - std::vector GetOutput() { return ExtractVector(output_); } +class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + HybridEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape) + : BaseEmbeddingLookupOpModel(index_shape, weight_shape, + TensorType_UINT8) {} - private: - int input_; - int weight_; - int output_; + void SetWeight(std::initializer_list data) { + SymmetricQuantizeAndPopulate(weight_, data); + } }; // TODO(ahentz): write more tests that exercise the details of the op, such as // lookup errors and variable input shapes. TEST(EmbeddingLookupOpTest, SimpleTest) { EmbeddingLookupOpModel m({3}, {3, 2, 4}); - m.PopulateTensor(0, {1, 0, 2}); + m.SetInput({1, 0, 2}); m.Set3DWeightMatrix( [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); @@ -84,6 +103,69 @@ TEST(EmbeddingLookupOpTest, SimpleTest) { }))); } +TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 8}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + +TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + +TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index 989920622dff1fe246efb920e0d18efa5f8e9215..f6fc0f5b6ad12d58c541efc6eae566ab4b8327f4 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -105,7 +105,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const int batch_size = input_size / filter->dims->data[1]; const int num_units = filter->dims->data[0]; - TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]); + TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter->dims->data[1]); if (bias) { TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0)); } @@ -118,8 +118,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); TF_LITE_ENSURE(context, real_multiplier < 1.0); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 0a5223b23529ef80b251d5144a94c5969c5cc02c..7962fcbc9d6c839ea11d7355e955239194442e03 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -176,6 +176,40 @@ cc_library( }), ) +cc_library( + name = "legacy_optimized_base", + srcs = [], + hdrs = [ + "common.h", + "optimized/depthwiseconv_float.h", + "optimized/depthwiseconv_uint8.h", + "optimized/depthwiseconv_uint8_3x3_filter.h", + "optimized/legacy_optimized_ops.h", + "optimized/optimized_ops.h", + ], + copts = tflite_copts(), + deps = [ + ":quantization_util", + ":strided_slice_logic", + ":types", + ":legacy_reference_base", + ":round", + "//third_party/eigen3", + "@gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, + ":freebsd": tflite_deps_intel, + "//conditions:default": [], + }), +) + cc_library( name = "optimized", hdrs = [ @@ -273,6 +307,37 @@ cc_library( }), ) +cc_library( + name = "legacy_reference_base", + srcs = [], + hdrs = [ + "common.h", + "reference/depthwiseconv_float.h", + "reference/depthwiseconv_uint8.h", + "reference/legacy_reference_ops.h", + "reference/reference_ops.h", + ], + deps = [ + ":quantization_util", + ":round", + ":strided_slice_logic", + ":types", + "//third_party/eigen3", + "@gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, + ":freebsd": tflite_deps_intel, + "//conditions:default": [], + }), +) + cc_library( name = "reference", hdrs = ["tensor.h"], @@ -474,8 +539,9 @@ cc_test( ) cc_test( - name = "resize_bilinear_float_test", - srcs = ["resize_bilinear_float_test.cc"], + name = "resize_bilinear_test", + srcs = ["resize_bilinear_test.cc"], + tags = ["tflite_not_portable"], deps = [ ":optimized_base", ":reference_base", diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 67e3810479af005b6a8d871420d2a33101788f3d..36c25388e8bde721d7644dc83d5b7c490d37b4d3 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -63,6 +63,8 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, // Quantize input from float to uint8 + quantization params (scaling // factor). float unused_min, unused_max; + // TODO(mirkov,raziel): replace this for-loop with a MACRO (or function) + // whichever is faster. for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; tensor_utils::SymmetricQuantizeFloats( @@ -147,6 +149,7 @@ void LstmStep( input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, input_gate_scratch, /*result_stride=*/1); } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1); @@ -161,8 +164,7 @@ void LstmStep( if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch, - /*result_stride=*/1); + n_batch, input_gate_scratch, /*result_stride=*/1); } tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, @@ -253,5 +255,263 @@ void LstmStep( output_state_ptr); } +// TODO(alanchiao): move this to tensor_utils. +void VectorMultiply(const int8_t* vector, const int v_size, const float scale, + float* result) { + for (int i = 0; i < v_size; ++i) { + *result++ = scale * *vector++; + } +} + +void LstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + + if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, + &unused_min, &unused_max, &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, forget_gate_scratch, + /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, output_gate_scratch, + /*result_stride=*/1); + } + + if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_output; + tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, + &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } + + // Save quantization and matmul computation for all zero input. + bool is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_input_weights_ptr, n_cell, + 1. / cell_to_input_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_forget_weights_ptr, n_cell, + 1. / cell_to_forget_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, + params->cell_clip, cell_state_ptr); + } + + is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + // For each batch and cell: update the output gate. + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_output_weights_ptr, n_cell, + 1. / cell_to_output_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_cell; + tensor_utils::SymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * projection_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr, + product_scaling_factors, n_batch, output_ptr_batch, + /*result_stride=*/1); + } + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, + params->proj_clip, output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); + } + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); +} + } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index f3f42f0840fc6a35d95adb3eaa0d621cc8bad8e2..2a11b37a6069367e8232350c2fc68d4c385e14ba 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -92,6 +92,89 @@ void LstmStep( float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch); +// Same as above but with quantized weight matrices. In detail: +// Input of size 'n_batch * n_input': +// input_ptr_batch +// +// LSTM weights: +// Quantized input weights of size 'n_cell * n_input': +// input_to_input_weights - optional (can be nullptr) +// input_to_forget_weights +// input_to_cell_weights +// input_to_input_weights +// Quantized recurrent weights of size 'n_cell * n_output': +// recurrent_to_input_weights - optional +// recurrent_to_forget_weights +// recurrent_to_cell_weights +// recurrent_to_input_weights +// Quantized peephole weights of size 'n_cell', representing diagonal matrices. +// cell_to_input_weights - optional +// cell_to_cell_weights - optional +// cell_to_output_weights - optional +// Quantized projection weights of size 'n_output * n_cell' +// projection_weights_ptr - optional +// Weight scales (scalars) for each of the weights above. +// input_to_input_weights_scale - optional +// input_to_forget_weights_scale +// input_to_cell_weights_scale +// input_to_output_weights_scale +// recurrent_to_input_weights_scale - optional +// recurrent_to_forget_weights_scale +// recurrent_to_cell_weights_scale +// recurrent_to_output_weights_scale +// cell_to_input_weights_scale, +// cell_to_forget_weights_scale, +// cell_to_output_weights_scale, +// projection_weights_scale - optional +// Gate biases of size 'n_cell': +// input_gate_bias_ptr - optional +// forget_gate_bias_ptr +// cell_gate_bias_ptr +// output_gate_bias_ptr +// +// Temporary pre-allocated storage for quantized values: +// quantized_input_ptr_batch (same size as input_ptr_batch) +// quantized_output_state_ptr (same size as output_state_ptr) +// quantized_cell_state_ptr (same size as cell_state_ptr) +// Temporary pre-allocated storage for recovered values: +// recovered_cell_weights (same size as cell_to_*_weights) +// +// Outputs: +// output_state_ptr - size 'n_batch * n_output' +// cell_state_ptr - size 'n_batch * n_cell' +// output_ptr_batch - size 'n_batch * n_output' +void LstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch); + } // namespace kernel_utils } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc index b7531ea2e202cd6fe012e0fa675380775016d38f..d2f1103e14b40b81c59c8053bcdbee30c85e5c78 100644 --- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -32,19 +32,21 @@ namespace tflite { namespace { void RunLogSoftmaxFloatReference(const uint8* input_data, - const Dims<4>& dims_common, int32 input_offset, - const double input_scale, int stride, - float beta, uint8* reference_output_data) { - const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, + uint8* reference_output_data) { + const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); // Reference data generated via Dequant of input into float, and then applying // float LogSoftmax. - reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, - reference_dequant_data.data(), dims_common); - optimized_ops::LogSoftmax(reference_dequant_data.data(), dims_common, - reference_output_float_data.data(), dims_common); + reference_ops::Dequantize( + input_data, ToRuntimeDims(shape_common), input_offset, input_scale, + reference_dequant_data.data(), ToRuntimeDims(shape_common)); + optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common, + reference_output_float_data.data(), shape_common); // Work with quantized scaling for LogSoftmax, under which 255 represents 0, // and -16 gets nudged up to 0. for (int i = 0; i < ref_buffer_size; i++) { @@ -55,9 +57,9 @@ void RunLogSoftmaxFloatReference(const uint8* input_data, } void CheckOutputData(const uint8* test_output, const uint8* reference_output, - const Dims<4>& dims_common, const string& check_label, - bool be_exacting) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + const string& check_label, bool be_exacting) { + const int buffer_size = shape_common.FlatSize(); // While calculating some metrics in floating point, we work with quantized // scaling. std::vector diff(buffer_size); @@ -99,15 +101,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output, // Runs the LogSoftmax and compares against the float reference implementation // and the quantized reference implementation. -void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, - int32 input_offset, const double input_scale, - int stride, float beta) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); +void RunOneLogSoftmaxTest(const uint8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); std::vector optimized_logsoftmax_output(buffer_size); std::vector reference_float_logsoftmax_output(buffer_size); std::vector reference_quant_logsoftmax_output(buffer_size); - RunLogSoftmaxFloatReference(input_data, dims_common, input_offset, + RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_logsoftmax_output.data()); @@ -116,32 +118,33 @@ void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, int32 reverse_scaling_divisor; int reverse_scaling_right_shift; static const int kScaledDiffIntegerBits = 5; - tflite::PreprocessLogSoftmaxScaling( + tflite::PreprocessLogSoftmaxScalingExp( beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier, &input_beta_left_shift, &reverse_scaling_divisor, &reverse_scaling_right_shift); + reverse_scaling_right_shift *= -1; // diff_min has a negative value, and is used to limit the maximum magnitude // of the diffs, which are <= 0. const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, input_beta_left_shift); - optimized_ops::LogSoftmax(input_data, dims_common, input_beta_multiplier, + optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, - optimized_logsoftmax_output.data(), dims_common); + optimized_logsoftmax_output.data(), shape_common); reference_ops::LogSoftmax( - input_data, dims_common, input_beta_multiplier, input_beta_left_shift, + input_data, shape_common, input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, - reference_quant_logsoftmax_output.data(), dims_common); + reference_quant_logsoftmax_output.data(), shape_common); CheckOutputData(optimized_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), dims_common, + reference_float_logsoftmax_output.data(), shape_common, "Optimized vs float reference", false); CheckOutputData(optimized_logsoftmax_output.data(), - reference_quant_logsoftmax_output.data(), dims_common, + reference_quant_logsoftmax_output.data(), shape_common, "Optimized vs quant reference", true); CheckOutputData(reference_quant_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), dims_common, + reference_float_logsoftmax_output.data(), shape_common, "Quant reference vs float reference", false); } @@ -164,13 +167,13 @@ bool TryOneUniformLogSoftmax() { const int32 input_offset = UniformRandomInt(-256, 0); static constexpr float beta = 1.0f; - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandom(&input_data); - RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } @@ -202,14 +205,14 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) { const int middle_min = UniformRandomInt(0, 255); const int sides_max = UniformRandomInt(0, middle_min); - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); - RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index a7b0d805a3acd35b592a35ba4266dfff4eb992cd..4cfaa0f36defa9c1f7d4a51af243c416bf09e331 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -26,7 +26,7 @@ namespace optimized_ops { // Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on // Jetson TX-2. This compiler does not support the offsetof() macro. #if defined(__aarch64__) && !defined(GOOGLE_L4T) - +#include // clang-format gets confused with this file and ends up formatting lines to // be larger than 80 characters. Turn off here and back on at the end of the // file. diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..7816752132761d9523ffc1f45b3740c0817ed402 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -0,0 +1,324 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ + +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Unoptimized reference ops: +using reference_ops::Relu1; +using reference_ops::Relu6; + +inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { + return RuntimeShape( + {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, + output_data, DimsToShape(output_dims)); +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), beta, output_data, + DimsToShape(output_dims)); +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier, + input_beta_left_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_multiplier, int32 input_left_shift, + int32 reverse_scaling_divisor, + int32 reverse_scaling_right_shift, int diff_min, + uint8* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier, + input_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const int16* input_data, const Dims<4>& input_dims, + int16* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const int16* input_data, const Dims<4>& input_dims, + int input_left_shift, int16* output_data, + const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data, + DimsToShape(output_dims)); +} + +} // namespace optimized_ops +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 0ce781db59a2cff0e0c199244b867fddf98804d6..868269477e9d2097607938929d19303a38cfb5c5 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -40,16 +40,29 @@ namespace tflite { namespace optimized_ops { // Unoptimized reference ops: +using reference_ops::ArgMax; using reference_ops::BroadcastGreater; using reference_ops::BroadcastGreaterEqual; using reference_ops::BroadcastLess; using reference_ops::BroadcastLessEqual; +using reference_ops::Concatenation; +using reference_ops::DepthConcatenation; +using reference_ops::Dequantize; +using reference_ops::Div; +using reference_ops::FakeQuant; +using reference_ops::Gather; using reference_ops::Greater; using reference_ops::GreaterEqual; using reference_ops::Less; using reference_ops::LessEqual; +using reference_ops::Mean; using reference_ops::RankOneSelect; +using reference_ops::Relu1; +using reference_ops::Relu6; using reference_ops::Select; +using reference_ops::SpaceToBatchND; +using reference_ops::StridedSlice; +using reference_ops::Transpose; // TODO(b/80247582) Remove this constant. // This will be phased out as the shifts are revised with more thought. Use of a @@ -72,6 +85,12 @@ using VectorMap = typename std::conditional< Eigen::Dynamic, 1>>, Eigen::Map>>::type; +template +VectorMap MapAsVector(Scalar* data, const RuntimeShape& shape) { + const int size = shape.FlatSize(); + return VectorMap(data, size, 1); +} + template VectorMap MapAsVector(Scalar* data, const Dims& dims) { const int size = FlatSize(dims); @@ -88,6 +107,23 @@ using MatrixMap = typename std::conditional< Eigen::Dynamic, Eigen::Dynamic>>, Eigen::Map>>::type; +template +MatrixMap MapAsMatrixWithLastDimAsRows(Scalar* data, + const RuntimeShape& shape) { + const int dims_count = shape.DimensionsCount(); + const int rows = shape.Dims(dims_count - 1); + const int cols = FlatSizeSkipDim(shape, dims_count - 1); + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithFirstDimAsCols(Scalar* data, + const RuntimeShape& shape) { + const int cols = shape.Dims(0); + const int rows = FlatSizeSkipDim(shape, 0); + return MatrixMap(data, rows, cols); +} + template MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, const Dims& dims) { @@ -1082,10 +1118,10 @@ struct GemmlowpOutputPipeline { gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> Pipeline; - static Pipeline Make(const int32* bias_data, int output_rows, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max) { + static Pipeline MakeExp(const int32* bias_data, int output_rows, + int32 output_offset, int32 output_multiplier, + int output_left_shift, int32 output_activation_min, + int32 output_activation_max) { ColVectorMap bias_vector(bias_data, output_rows); gemmlowp::OutputStageBiasAddition bias_addition_stage; bias_addition_stage.bias_vector = bias_vector; @@ -1093,7 +1129,7 @@ struct GemmlowpOutputPipeline { quantize_down_stage; quantize_down_stage.result_offset_after_shift = output_offset; quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; - quantize_down_stage.result_shift = output_shift; + quantize_down_stage.result_shift = -output_left_shift; gemmlowp::OutputStageClamp clamp_stage; clamp_stage.min = output_activation_min; clamp_stage.max = output_activation_max; @@ -1146,8 +1182,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, input_data, filter_cols, batches, filter_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, batches, output_rows); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -1821,8 +1857,8 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, // Use dimensions M and N to construct dims for indexing directly into im2col Dims<4> im2col_dims; - im2col_dims.sizes[0] = col_dims.strides[3]; - im2col_dims.sizes[1] = row_dims.strides[3]; + im2col_dims.sizes[0] = FlatSize(col_dims); + im2col_dims.sizes[1] = FlatSize(row_dims); im2col_dims.sizes[2] = 1; im2col_dims.sizes[3] = 1; ComputeStrides(&im2col_dims); @@ -1831,8 +1867,8 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { - // Each row is an output pixel. Arrange the input data into this row in - // an order we can conveniently multiply with the filter data. + // Each im2col row is an output pixel. Arrange the input data in this + // row in an order we can conveniently multiply with the filter data. int row_offset = Offset(row_dims, out_x, out_y, batch, 0); const int in_x_origin = (out_x * stride_width) - pad_width; const int in_y_origin = (out_y * stride_height) - pad_height; @@ -1848,7 +1884,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, T* dst = im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); if ((in_x >= 0) && (in_x < input_width)) { - // Filter pixel is within the input, copy the data. + // Filter pixel is within the input, copy the input data. T const* src = input_data + Offset(input_dims, 0, in_x, in_y, batch); memcpy(dst, src, input_depth * sizeof(T)); @@ -1858,7 +1894,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, } } } else { - // Filter row is outside the input, zero out the entire im2col row. + // Filter row is outside the input, zero out the entire filter row. int col_offset = Offset(col_dims, 0, 0, filter_y, 0); T* dst = im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); @@ -1922,7 +1958,7 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, (void)im2col_dims; gemmlowp::ScopedProfilingLabel label("Conv"); - // A float set to 0x00000000h == 0.0f + // NB: static_cast(0x00000000h) == 0.0f const uint8 float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const Dims<4>* gemm_input_dims = nullptr; @@ -2084,8 +2120,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, gemm_input_data, gemm_input_rows, gemm_input_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, output_cols); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -2242,8 +2278,8 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, input_data, filter_cols, output_cols, filter_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, output_cols, output_rows); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -2330,48 +2366,25 @@ void GlobalBatchNormalization(const float* input_data, } } -inline void Relu(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Relu(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Relu (not fused)"); - const auto input = MapAsVector(input_data, input_dims); - auto output = MapAsVector(output_data, output_dims); + const auto input = MapAsVector(input_data, input_shape); + auto output = MapAsVector(output_data, output_shape); output = input.cwiseMax(0.0f); } -inline void Relu1(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); - const int flat_size = MatchingFlatSize(input_dims, output_dims); - for (int i = 0; i < flat_size; ++i) { - const float val = input_data[i]; - const float upper = 1; - const float lower = -1; - const float clamped = val > upper ? upper : val < lower ? lower : val; - output_data[i] = clamped; - } -} - -inline void Relu6(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); - const int flat_size = MatchingFlatSize(input_dims, output_dims); - for (int i = 0; i < flat_size; ++i) { - const float val = input_data[i]; - const float upper = 6; - const float lower = 0; - const float clamped = val > upper ? upper : val < lower ? lower : val; - output_data[i] = clamped; - } -} - template -void L2Normalization(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +void L2Normalization(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("L2Normalization"); static_assert(Ac == FusedActivationFunctionType::kNone, ""); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { float squared_l2_norm = 0; for (int c = 0; c < depth; ++c) { @@ -2387,8 +2400,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, } } -inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, - int* output_shift) { +inline void GetInvSqrtQuantizedMultiplierExp(int32 input, + int32* output_inv_sqrt, + int* output_shift) { *output_shift = 11; while (input >= (1 << 29)) { input /= 4; @@ -2430,31 +2444,35 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, *output_inv_sqrt <<= -*output_shift; *output_shift = 0; } + *output_shift *= kReverseShift; } -inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, +inline void L2Normalization(const uint8* input_data, + const RuntimeShape& input_shape, int32 input_zero_point, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit"); - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); for (int i = 0; i < outer_size; ++i) { int32 square_l2_norm = 0; for (int c = 0; c < depth; c++) { + // Note that input_data advances by depth in the second pass below. int32 diff = input_data[c] - input_zero_point; square_l2_norm += diff * diff; } int32 inv_l2norm_multiplier; int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); + GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); for (int c = 0; c < depth; c++) { int32 diff = *input_data - input_zero_point; int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( - 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift); + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); int32 unclamped_output_val = 128 + rescaled_diff; int32 output_val = std::min(255, std::max(0, unclamped_output_val)); *output_data = static_cast(output_val); @@ -2663,25 +2681,13 @@ inline void Add(int left_shift, const uint8* input1_data, output_activation_max, output_data); } -template inline void Add(const int16* input1_data, const Dims<4>& input1_dims, int input1_shift, const int16* input2_data, const Dims<4>& input2_dims, int input2_shift, int16 output_activation_min, int16 output_activation_max, int16* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("Add/Int16"); - // This is a copy of the reference implementation. We do not currently have a - // properly optimized version. - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, -32768); - TFLITE_DCHECK_EQ(output_activation_max, 32767); - } const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); @@ -2707,6 +2713,28 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, } } +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims, + input2_shift, output_activation_min, output_activation_max, output_data, + output_dims); +} + template void Add(const int32* input1_data, const Dims<4>& input1_dims, const int32* input2_data, const Dims<4>& input2_dims, @@ -3207,19 +3235,6 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, output_data, output_dims); } -// TODO(aselle): This is not actually optimized yet. -inline void Div(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); - for (int i = 0; i < flat_size; i++) { - output_data[i] = ActivationFunctionWithMinMax( - input1_data[i] / input2_data[i], output_activation_min, - output_activation_max); - } -} - // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -3385,105 +3400,6 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, } } -template -void Concatenation(int concat_dim, const Scalar* const* input_data, - const Dims<4>* const* input_dims, int inputs_count, - Scalar* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Concatenation"); - int concat_size = 0; - for (int i = 0; i < inputs_count; i++) { - for (int j = 0; j < 4; j++) { - if (j != concat_dim) { - MatchingArraySize(*input_dims[i], j, output_dims, j); - } - } - concat_size += ArraySize(*input_dims[i], concat_dim); - } - TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); - // for now we dont have a model with a Concatenation - // with fused activation function. - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); - int outer_size = 1; - for (int i = concat_dim + 1; i < 4; i++) { - outer_size *= output_dims.sizes[i]; - } - Scalar* output_ptr = output_data; - for (int k = 0; k < outer_size; k++) { - for (int i = 0; i < inputs_count; ++i) { - const int copy_size = - input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; - memcpy(output_ptr, input_data[i] + k * copy_size, - copy_size * sizeof(Scalar)); - output_ptr += copy_size; - } - } -} - -// TODO(prabhumk): This is the same as the reference implementation. -// TODO(prabhumk): The quantized implementation of concatentation isn't fully -// quantized as it takes scale as a floating point value. This should be fixed -// when optimizng this routine further. -inline void Concatenation(int concat_dim, const uint8* const* input_data, - const Dims<4>* const* input_dims, - const int32* input_zeropoint, - const float* input_scale, int inputs_count, - uint8* output_data, const Dims<4>& output_dims, - const int32 output_zeropoint, - const float output_scale) { - // The arguments input_zeropoint and input_scale are expected to be an array - // that have the quantization parameters for all the inputs to the concat - // operator. - gemmlowp::ScopedProfilingLabel label("Concatenation"); - TFLITE_DCHECK_GT(inputs_count, 1); - int concat_size = 0; - for (int i = 0; i < inputs_count; i++) { - for (int j = 0; j < 4; j++) { - if (j != concat_dim) { - MatchingArraySize(*input_dims[i], j, output_dims, j); - } - } - concat_size += ArraySize(*input_dims[i], concat_dim); - } - TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - int outer_size = 1; - for (int i = concat_dim + 1; i < 4; i++) { - outer_size *= output_dims.sizes[i]; - } - const float inverse_output_scale = 1.f / output_scale; - uint8* output_ptr = output_data; - for (int k = 0; k < outer_size; k++) { - for (int i = 0; i < inputs_count; ++i) { - const int copy_size = - input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; - const uint8* input_ptr = input_data[i] + k * copy_size; - if (input_zeropoint[i] == output_zeropoint && - input_scale[i] == output_scale) { - memcpy(output_ptr, input_ptr, copy_size); - } else { - const float scale = input_scale[i] * inverse_output_scale; - const float bias = -input_zeropoint[i] * scale; - for (int j = 0; j < copy_size; ++j) { - const int32_t value = - static_cast(round(input_ptr[j] * scale + bias)) + - output_zeropoint; - output_ptr[j] = - static_cast(std::max(std::min(255, value), 0)); - } - } - output_ptr += copy_size; - } - } -} - -template -void DepthConcatenation(const Scalar* const* input_data, - const Dims<4>* const* input_dims, int inputs_count, - Scalar* output_data, const Dims<4>& output_dims) { - Concatenation(0, input_data, input_dims, inputs_count, - output_data, output_dims); -} - inline void LstmCell(const float* input_data, const Dims<4>& input_dims, const float* prev_activ_data, const Dims<4>& prev_activ_dims, const float* weights_data, @@ -3846,23 +3762,25 @@ inline int NodeOffset(int b, int h, int w, int height, int width) { return (b * height + h) * width + w; } -inline void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int kwidth, int kheight, - float output_activation_min, +inline void AveragePool(const float* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float output_activation_min, float output_activation_max, float* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("AveragePool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); // TODO(benoitjacob) make this a proper reference impl without Eigen! - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // TODO(benoitjacob) get rid of the dynamic memory allocation here! Eigen::VectorXf out_count(out_mat.cols()); out_count.setZero(); @@ -3900,9 +3818,9 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, for (int y = 0; y < output_height; ++y) { for (int x = 0; x < output_width; ++x) { for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + output_data[Offset(output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - output_data[Offset(output_dims, c, x, y, b)], + output_data[Offset(output_shape, b, y, x, c)], output_activation_min, output_activation_max); } } @@ -3910,44 +3828,23 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int kwidth, int kheight, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, float* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, +inline void AveragePool(const uint8* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("AveragePool/8bit"); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -3967,11 +3864,12 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, uint16 acc[kAccBufferMaxSize]; memset(acc, 0, depth * sizeof(acc[0])); const uint8* input_ptr = - input_data + input_dims.strides[1] * in_x_origin + - input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + input_data + + depth * (in_x_origin + + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + - filter_x_start * input_dims.strides[1]; + const uint8* input_row_ptr = + input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { int channel = 0; #ifdef USE_NEON @@ -4002,7 +3900,7 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } uint8* output_ptr = - output_data + Offset(output_dims, 0, out_x, out_y, batch); + output_data + Offset(output_shape, batch, out_y, out_x, 0); int channel = 0; #ifdef USE_NEON #define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ @@ -4043,54 +3941,23 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void MaxPool(const float* input_data, const Dims<4>& input_dims, +inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int kwidth, int kheight, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("MaxPool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // Prefill the output to minimum representable float value out_mat.setConstant(std::numeric_limits::lowest()); for (int b = 0; b < batches; ++b) { @@ -4123,9 +3990,9 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, for (int y = 0; y < output_height; ++y) { for (int x = 0; x < output_width; ++x) { for (int c = 0; c < depth; ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + output_data[Offset(output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - output_data[Offset(output_dims, c, x, y, b)], + output_data[Offset(output_shape, b, y, x, c)], output_activation_min, output_activation_max); } } @@ -4133,41 +4000,21 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int kwidth, int kheight, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, kwidth, kheight, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, +inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("MaxPool/8bit"); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -4185,11 +4032,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, uint8 acc[kAccBufferMaxSize]; memset(acc, 0, depth * sizeof(acc[0])); const uint8* input_ptr = - input_data + input_dims.strides[1] * in_x_origin + - input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + input_data + + depth * (in_x_origin + + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + - filter_x_start * input_dims.strides[1]; + const uint8* input_row_ptr = + input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { int channel = 0; #ifdef USE_NEON @@ -4215,7 +4063,7 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } uint8* output_ptr = - output_data + Offset(output_dims, 0, out_x, out_y, batch); + output_data + Offset(output_shape, batch, out_y, out_x, 0); int channel = 0; #ifdef USE_NEON for (; channel <= depth - 16; channel += 16) { @@ -4242,53 +4090,23 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void L2Pool(const float* input_data, const Dims<4>& input_dims, +inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("L2Pool"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); // Actually carry out L2 Pool. Code is written in forward mode: we go through // the input values once, and write to all the pooled regions that it maps to. - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); Eigen::VectorXf in_square(in_mat.rows()); Eigen::VectorXf out_count(out_mat.cols()); out_count.setZero(); @@ -4330,28 +4148,6 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt(); } -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - inline void LocalResponseNormalization(const float* input_data, const Dims<4>& input_dims, int range, float bias, float alpha, float beta, @@ -4397,14 +4193,14 @@ inline void LocalResponseNormalization(const float* input_data, } } -inline void Softmax(const float* input_data, const Dims<4>& input_dims, +inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Softmax"); - MatchingFlatSize(input_dims, output_dims); + MatchingFlatSize(input_shape, output_shape); - const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // Compute the exponential first, removing the max coefficient for numerical // stability. out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta; @@ -4416,10 +4212,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, out_mat.array().rowwise() *= scale; } -inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, +inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_beta_multiplier, int32 input_beta_left_shift, int diff_min, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -4433,8 +4229,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPoint0 = gemmlowp::FixedPoint; gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int b = 0; b < outer_size; ++b) { const uint8* input_data_ptr = input_data + b * depth; @@ -4624,11 +4423,14 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, // TODO(myenik): This is the same as the reference implementation, not actually // optimized yet. -inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("LogSoftmax"); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { const float* block_input_data = input_data + i * depth; @@ -4769,11 +4571,11 @@ log_x_for_x_greater_than_or_equal_to_1( } // Currently just a copy of the reference code. -inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, +inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_multiplier, int32 input_left_shift, int32 reverse_scaling_divisor, int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8"); // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as @@ -4788,8 +4590,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { const uint8* block_input_data = input_data + i * depth; @@ -4853,21 +4658,21 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Logistic(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic"); - auto input_map = MapAsVector(input_data, input_dims); - auto output_map = MapAsVector(output_data, output_dims); + auto input_map = MapAsVector(input_data, input_shape); + auto output_map = MapAsVector(output_data, output_shape); output_map.array() = input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op()); } -inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, +inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic/Uint8"); - const int size = MatchingFlatSize(input_dims, output_dims); + const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; #ifdef USE_NEON @@ -4999,10 +4804,10 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const int16* input_data, const Dims<4>& input_dims, - int16* output_data, const Dims<4>& output_dims) { +inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, + int16* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Logistic/Int16"); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { } @@ -5059,21 +4864,21 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +inline void Tanh(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Tanh"); - auto input_map = MapAsVector(input_data, input_dims); - auto output_map = MapAsVector(output_data, output_dims); + auto input_map = MapAsVector(input_data, input_shape); + auto output_map = MapAsVector(output_data, output_shape); output_map.array() = input_map.array().tanh(); } -inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, +inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { // Note that this is almost the exact same code as in Logistic(). gemmlowp::ScopedProfilingLabel label("Tanh"); - const int size = MatchingFlatSize(input_dims, output_dims); + const int size = MatchingFlatSize(input_shape, output_shape); int c = 0; int32_t output_zero_point = 128; @@ -5214,16 +5019,16 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const int16* input_data, const Dims<4>& input_dims, +inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, int input_left_shift, int16* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { gemmlowp::ScopedProfilingLabel label("Tanh/Int16"); // Support for shifts is limited until we have a parameterized version of // SaturatingRoundingMultiplyByPOT(). TFLITE_DCHECK_GE(input_left_shift, 0); TFLITE_DCHECK_LE(input_left_shift, 1); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); int c = 0; const int16* input_data_ptr = input_data; @@ -5314,49 +5119,6 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims, } } -inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, - int32 zero_point, double scale, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Dequantize"); - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; ++i) { - int32 val = input_data[i]; - float result = static_cast(scale * (val - zero_point)); - output_data[i] = result; - } -} - -inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, - float rmin, float rmax, int num_bits, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("FakeQuant"); - - // 0 should always be a representable value. Let's assume that the initial - // min,max range contains 0. - TFLITE_DCHECK_LE(rmin, 0.0f); - TFLITE_DCHECK_GE(rmax, 0.0f); - TFLITE_DCHECK_LT(rmin, rmax); - - // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor. - int quant_min = 0; - int quant_max = (1 << num_bits) - 1; - float nudged_min, nudged_max, nudged_scale; - NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, - &nudged_max, &nudged_scale); - const float inv_nudged_scale = 1.0f / nudged_scale; - - const int flat_size = MatchingFlatSize(output_dims, input_dims); - for (int i = 0; i < flat_size; ++i) { - const float src_val = input_data[i]; - const float clamped = std::min(nudged_max, std::max(nudged_min, src_val)); - const float clamped_shifted = clamped - nudged_min; - const float dst_val = - TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale + - nudged_min; - output_data[i] = dst_val; - } -} - template inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data, const Dims<4>& output_dims) { @@ -5374,26 +5136,6 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims, output_map.array() = Eigen::floor(input_map.array()); } -template -inline void Gather(const T* input_data, const Dims<4>& input_dims, - int input_rank, const int32* coords_data, - const Dims<4>& coords_dims, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Gather"); - - TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); - int stride = input_dims.strides[input_rank - 1]; - T* out = output_data; - - for (int i = 0; i < coords_dims.sizes[0]; i++) { - TFLITE_DCHECK_GE(coords_data[i], 0); - TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); - const T* in = input_data + coords_data[i] * stride; - memcpy(out, in, sizeof(T) * stride); - out += stride; - } -} - #ifdef USE_NEON inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, float scale, float* output_ptr) { @@ -5722,6 +5464,46 @@ inline void ResizeBilinearGeneric(const float* input_data, } } +template +inline void ResizeBilinearGenericSmallChannel( + const T* input_data, const Dims<4>& input_dims, T* output_data, + const Dims<4>& output_dims, int32 batches, int32 input_height, + int32 input_width, int32 depth, int32 output_height, int32 output_width, + float height_scale, float width_scale) { + memset(output_data, 0, + batches * output_height * output_width * depth * sizeof(T)); + + T* output_ptr = &output_data[0]; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(input_x); + int32 x1 = std::min(x0 + 1, input_width - 1); + + int32 input_offset[4] = { + Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b), + Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)}; + float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)), + (1 - (input_y - y0)) * (input_x - x0), + (input_y - y0) * (1 - (input_x - x0)), + (input_y - y0) * (input_x - x0)}; + + for (int d = 0; d < depth; d++) { + const T* input_ptr = &input_data[d]; + *output_ptr++ = static_cast(input_ptr[input_offset[0]] * scale[0] + + input_ptr[input_offset[1]] * scale[1] + + input_ptr[input_offset[2]] * scale[2] + + input_ptr[input_offset[3]] * scale[3]); + } + } + } + } +} + inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, @@ -5762,6 +5544,41 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8 +// or int16 arithmetic. +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims, bool align_corners) { + gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + + float height_scale = + (align_corners && output_height > 1) + ? (static_cast(input_height - 1) / (output_height - 1)) + : (static_cast(input_height) / output_height); + + float width_scale = + (align_corners && output_width > 1) + ? (static_cast(input_width - 1) / (output_width - 1)) + : (static_cast(input_width) / output_width); + + ResizeBilinearGenericSmallChannel( + input_data, input_dims, output_data, output_dims, batches, input_height, + input_width, depth, output_height, output_width, height_scale, + width_scale); +} + // legacy, for compatibility with old checked-in code inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, @@ -5771,53 +5588,13 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, output_data, output_dims, /*align_corners=*/false); } -template -inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, - const Dims<4>& block_shape_dims, - const int32* paddings_data, - const Dims<4>& paddings_dims, T* output_data, +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, const Dims<4>& output_dims) { - // Unoptimized - Straight copy from reference ops. - gemmlowp::ScopedProfilingLabel label("SpaceToBatchND"); - - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); - const int block_shape_height = block_shape_data[0]; - const int block_shape_width = block_shape_data[1]; - const int padding_top = paddings_data[0]; - const int padding_left = paddings_data[2]; - - for (int out_b = 0; out_b < output_batch_size; ++out_b) { - int input_batch = out_b % input_batch_size; - int shift_w = (out_b / input_batch_size) % block_shape_width; - int shift_h = (out_b / input_batch_size) / block_shape_width; - for (int out_h = 0; out_h < output_height; ++out_h) { - for (int out_w = 0; out_w < output_width; ++out_w) { - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); - if (out_h * block_shape_height + shift_h < padding_top || - out_h * block_shape_height + shift_h >= - padding_top + input_height || - out_w * block_shape_width + shift_w < padding_left || - out_w * block_shape_width + shift_w >= padding_left + input_width) { - memset(out, 0, depth * sizeof(T)); - } else { - const T* in = - input_data + - Offset(input_dims, 0, - (out_w * block_shape_width + shift_w) - padding_left, - (out_h * block_shape_height + shift_h) - padding_top, - input_batch); - memcpy(out, in, depth * sizeof(T)); - } - } - } - } + ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); } // Helper methods for BatchToSpaceND. @@ -6022,54 +5799,6 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, output_dims, 0); } -// UNOPTIMIZED COPY of StridedSlice from reference_ops.h. -template -inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, - const std::vector& start_indices, - const std::vector& stop_indices, - const std::vector& strides, T* output_data, - const Dims<4>& output_dims) { - TFLITE_DCHECK_EQ(start_indices.size(), 4); - TFLITE_DCHECK_EQ(stop_indices.size(), 4); - TFLITE_DCHECK_EQ(strides.size(), 4); - const int start_b = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 3); - const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 3); - const int start_h = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 2); - const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 2); - const int start_w = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 1); - const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 1); - const int start_d = strided_slice::StartForAxis(begin_mask, start_indices, - strides, input_dims.sizes, 0); - const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides, - input_dims.sizes, 0); - - T* out_ptr = output_data; - for (int in_b = start_b; - !strided_slice::LoopCondition(in_b, stop_b, strides[3]); - in_b += strides[3]) { - for (int in_h = start_h; - !strided_slice::LoopCondition(in_h, stop_h, strides[2]); - in_h += strides[2]) { - for (int in_w = start_w; - !strided_slice::LoopCondition(in_w, stop_w, strides[1]); - in_w += strides[1]) { - for (int in_d = start_d; - !strided_slice::LoopCondition(in_d, stop_d, strides[0]); - in_d += strides[0]) { - *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; - } - } - } - } -} - template inline void Slice(const T* input_data, const Dims<4>& input_dims, const std::vector& begin, const std::vector& size, @@ -6104,41 +5833,6 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } } -template -inline void Mean(const T* input_data, const Dims<4>& input_dims, - const std::vector& reduction_indices, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Mean"); - const int output_batch = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int output_depth = ArraySize(output_dims, 0); - - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - - // The current implementation only supports simultaneous reduction over - // width and height. - TFLITE_DCHECK_EQ(reduction_indices.size(), 2); - TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || - (reduction_indices[0] == 2 && reduction_indices[1] == 1)); - TFLITE_DCHECK_EQ(output_height, 1); - TFLITE_DCHECK_EQ(output_width, 1); - - for (int out_b = 0; out_b < output_batch; ++out_b) { - for (int out_d = 0; out_d < output_depth; ++out_d) { - float value = 0; - for (int in_h = 0; in_h < input_height; ++in_h) { - for (int in_w = 0; in_w < input_width; ++in_w) { - value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; - } - } - output_data[Offset(output_dims, out_d, 0, 0, out_b)] = - value / (input_width * input_height); - } - } -} - template void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, @@ -6218,130 +5912,84 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, output_map.array() = input1_map.array().max(max_value); } -template -void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, - T2* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("ArgMax"); - - // The current ArgMax implemention can only determine the index of the maximum - // value in the last dimension. So the axis argument is ignored. - - // For ArgMax, the number of output dimensions = (number of input dimensions - - // 1). For the sake of simplicity, the output dimensions are equal to the - // input dimensions here. We enforce the constraint that the last dimension - // must always be 1. - TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = ArraySize(input_dims, 0); - for (int i = 0; i < outer_size; ++i) { - auto max_value = *input_data; - ++input_data; - int max_index = 0; - for (int d = 1; d < depth; ++d) { - const auto& curr_value = *input_data; - if (curr_value > max_value) { - max_value = curr_value; - max_index = d; - } - ++input_data; - } - *output_data = max_index; - ++output_data; - } -} - template -void Transpose(const T* input, const Dims<4>& input_dims, T* output, - const Dims<4>& output_dims, const int* permuted_axes) { - int out_sizes[4]; - // Compute the inverse permutation array so we can do an output centered - // transpose. Also, check to make sure output_dims is matching input_dims. - for (int k = 0; k < 4; k++) { - out_sizes[k] = - MatchingArraySize(input_dims, permuted_axes[k], output_dims, k); - } - - // Naive transpose loop (iterate on output index and compute input index). - int o[4]; // loop index (on output). - int i[4]; - for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) { - i[permuted_axes[3]] = o[3]; - for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) { - i[permuted_axes[2]] = o[2]; - for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) { - i[permuted_axes[1]] = o[1]; - for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) { - i[permuted_axes[0]] = o[0]; - output[Offset(output_dims, o)] = input[Offset(input_dims, i)]; - } - } - } - } -} +void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, + const Dims<4>& filter_dims, int stride_width, + int stride_height, int pad_width, int pad_height, + const Dims<4>& output_dims, uint8 zero_byte, + T* im2col_data) { + gemmlowp::ScopedProfilingLabel label("TransposeIm2col"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK(im2col_data); -inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("TransposeConv"); - // THIS FUNCTION IS A COPY FROM reference_ops.h. - // To optimize, start by using the conv code with transposed weights for the - // case of stride_height = stride_width = 1. const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); const int filter_height = ArraySize(filter_dims, 2); const int filter_width = ArraySize(filter_dims, 1); const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); + MatchingArraySize(output_dims, 0, filter_dims, 0); // output_depth - // Although transpose convolution simplifies to convolution with transposed - // weights for strides of 1, non-unitary striding complicates matters. To - // keep this reference implementation as clear as possible, we use a "scatter" - // access pattern, where we loop through all the input elements, computing - // their influence on the output, rather than looping through the output - // elements in the typical "gather" access pattern of a conv. We therefore - // must initialize the output array to zero. - for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = - 0.0f; - } - } - } - } + // Construct the MxN sized im2col matrix. + // The rows M, are sub-ordered B x H x W + Dims<4> row_dims; + row_dims.sizes[0] = output_width; + row_dims.sizes[1] = output_height; + row_dims.sizes[2] = batches; + row_dims.sizes[3] = 1; + ComputeStrides(&row_dims); - // Loop through input elements one at a time. + // The columns, N, are sub-ordered Kh x Kw x Din + Dims<4> col_dims; + col_dims.sizes[0] = input_depth; + col_dims.sizes[1] = filter_width; + col_dims.sizes[2] = filter_height; + col_dims.sizes[3] = 1; + ComputeStrides(&col_dims); + + // Use dimensions M and N to construct dims for indexing directly into im2col + Dims<4> im2col_dims; + im2col_dims.sizes[0] = FlatSize(col_dims); + im2col_dims.sizes[1] = FlatSize(row_dims); + im2col_dims.sizes[2] = 1; + im2col_dims.sizes[3] = 1; + ComputeStrides(&im2col_dims); + + // Build the im2col matrix by looping through all the input pixels, + // computing their influence on the output, rather than looping through all + // the output pixels. We therefore must initialize the im2col array to zero. + // This is potentially inefficient because we subsequently overwrite bytes + // set here. However, in practice memset is very fast and costs negligible. + memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T)); + + // Loop through the output batches for (int batch = 0; batch < batches; ++batch) { + // Loop through input pixels one at a time. for (int in_y = 0; in_y < input_height; ++in_y) { for (int in_x = 0; in_x < input_width; ++in_x) { - for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - // Loop through the output elements it will influence - const int out_x_origin = (in_x * stride_width) - pad_width; - const int out_y_origin = (in_y * stride_height) - pad_height; - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + // Loop through the output pixels it will influence + const int out_x_origin = (in_x * stride_width) - pad_width; + const int out_y_origin = (in_y * stride_height) - pad_height; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int out_y = out_y_origin + filter_y; + // Is output pixel within height bounds? + if ((out_y >= 0) && (out_y < output_height)) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - for (int out_channel = 0; out_channel < output_depth; - ++out_channel) { - // Compute output element location - const int out_x = out_x_origin + filter_x; - const int out_y = out_y_origin + filter_y; - // We cannot accumulate out of bounds - if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) && - (out_y < output_height)) { - float input_value = input_data[Offset(input_dims, in_channel, - in_x, in_y, batch)]; - float filter_value = - filter_data[Offset(filter_dims, out_channel, filter_x, - filter_y, in_channel)]; - output_data[Offset(output_dims, out_channel, out_x, out_y, - batch)] += input_value * filter_value; - } + const int out_x = out_x_origin + filter_x; + // Is output pixel within width bounds? + if ((out_x >= 0) && (out_x < output_width)) { + // Copy the input elements of this pixel + T const* src = + input_data + Offset(input_dims, 0, in_x, in_y, batch); + T* dst = im2col_data + + Offset(im2col_dims, + Offset(col_dims, 0, filter_x, filter_y, 0), + Offset(row_dims, out_x, out_y, batch, 0), 0, 0); + memcpy(dst, src, input_depth * sizeof(T)); } } } @@ -6351,6 +5999,31 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } } +inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + gemmlowp::ScopedProfilingLabel label("TransposeConv"); + + // Note we could use transposed weights with forward conv for unstrided + // cases. But we are already getting good performance with this code as-is. + TFLITE_DCHECK(im2col_data); + TransposeIm2col(input_data, input_dims, filter_dims, stride_width, + stride_height, pad_width, pad_height, output_dims, 0, + im2col_data); + + const auto im2col_matrix_map = + MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map); +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index b0951aac8cbb98a181d9dcaef88770fadfc74f62..e224980493aa11f642da103ee7d7377b6c4b1da0 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include #include #include @@ -48,15 +49,15 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier, TFLITE_CHECK_GE(*left_shift, 0); } -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift) { +void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift) { TFLITE_CHECK_LT(double_multiplier, 1.); TFLITE_CHECK_GT(double_multiplier, 0.); int shift; QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); TFLITE_CHECK_LE(shift, 0); - *right_shift = -shift; + *left_shift = shift; } void PreprocessSoftmaxScaling(double beta, double input_scale, @@ -78,20 +79,21 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, quantized_multiplier, left_shift); } -void PreprocessLogSoftmaxScaling(double beta, double input_scale, - int input_integer_bits, - int32_t* quantized_multiplier, int* left_shift, - int32_t* reverse_scaling_divisor, - int* reverse_scaling_right_shift) { +void PreprocessLogSoftmaxScalingExp(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, + int* left_shift, + int32_t* reverse_scaling_divisor, + int* reverse_scaling_left_shift) { PreprocessSoftmaxScaling(beta, input_scale, input_integer_bits, quantized_multiplier, left_shift); // Also calculate what amounts to the inverse scaling factor for the input. const double real_reverse_scaling_divisor = (1 << (31 - *left_shift)) / static_cast(*quantized_multiplier); - tflite::QuantizeMultiplierSmallerThanOne(real_reverse_scaling_divisor, - reverse_scaling_divisor, - reverse_scaling_right_shift); + tflite::QuantizeMultiplierSmallerThanOneExp(real_reverse_scaling_divisor, + reverse_scaling_divisor, + reverse_scaling_left_shift); } int CalculateInputRadius(int input_integer_bits, int input_left_shift) { @@ -125,4 +127,16 @@ void NudgeQuantizationRange(const float min, const float max, *nudged_max = (quant_max_float - nudged_zero_point) * (*scale); } +bool CheckedLog2(const float x, int* log2_result) { + // Using TfLiteRound instead of std::round and std::log instead of + // std::log2 to work around these fuctions being missing in a toolchain + // used in some TensorFlow tests as of May 2018. + const float x_log2 = std::log(x) * (1.0f / std::log(2.0f)); + const float x_log2_rounded = TfLiteRound(x_log2); + const float x_log2_fracpart = x_log2 - x_log2_rounded; + + *log2_result = static_cast(x_log2_rounded); + return std::abs(x_log2_fracpart) < 1e-3; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index 4a217515f142b2451ebd61e423871b95cdc09748..525857a2e6f73276d0a6e64770947169033c7667 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -167,9 +167,9 @@ IntOut SafeCast(FloatIn x) { // this is intended as a RIGHT-shift. // // Restricted to the case where the multiplier < 1 (and non-negative). -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift); +void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift); // Decompose a double multiplier into a Q0.31 int32 representation of its // significand, and shift representation of its exponent. @@ -197,11 +197,12 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, int input_integer_bits, int32_t* quantized_multiplier, int* left_shift); // Like PreprocessSoftmaxScaling, but inverse scaling factors also calculated. -void PreprocessLogSoftmaxScaling(double beta, double input_scale, - int input_integer_bits, - int32_t* quantized_multiplier, int* left_shift, - int32_t* reverse_scaling_divisor, - int* reverse_scaling_right_shift); +void PreprocessLogSoftmaxScalingExp(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, + int* left_shift, + int32_t* reverse_scaling_divisor, + int* reverse_scaling_left_shift); // Calculate the largest input that will result in a within-bounds intermediate // result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words, // it must not overflow before we reduce the value by multiplication by the @@ -217,6 +218,11 @@ void NudgeQuantizationRange(const float min, const float max, const int quant_min, const int quant_max, float* nudged_min, float* nudged_max, float* scale); +// If x is approximately a power of two (with any positive or negative +// exponent), stores that exponent (i.e. log2(x)) in *log2_result, otherwise +// returns false. +bool CheckedLog2(const float x, int* log2_result); + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index 2d74b3d3849812a2dc95fabcd680aa280c99ca55..94773b47d3817d7ed7240f74545ad04e7fa4bd52 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -196,21 +196,21 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) { EXPECT_DEATH(ChooseQuantizationParams(10.0, -30.0), ""); } -TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { +TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) { auto quantize = [](double d) { int32_t q; int s; - QuantizeMultiplierSmallerThanOne(d, &q, &s); + QuantizeMultiplierSmallerThanOneExp(d, &q, &s); return std::pair{q, s}; }; EXPECT_DEATH(quantize(-0.1), ""); EXPECT_DEATH(quantize(0.0), ""); - EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); + EXPECT_THAT(quantize(0.25), Pair(1073741824, -1)); // Around 0.5 we can see the change in exponent and how we try hard to // void hitting max int32. - EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1)); + EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, -1)); EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0)); EXPECT_THAT(quantize(0.50), Pair(1073741824, 0)); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..878b2441b4f2828a014673f5bd80fb8aa29514db --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h @@ -0,0 +1,332 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ + +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +namespace reference_ops { + +inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) { + return RuntimeShape( + {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]}); +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, + output_data, DimsToShape(output_dims)); +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu1(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Relu6(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height, + pad_width, pad_height, filter_width, filter_height, + output_activation_min, output_activation_max, output_data, + DimsToShape(output_dims)); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), beta, output_data, + DimsToShape(output_dims)); +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier, + input_beta_left_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_multiplier, int32 input_left_shift, + int32 reverse_scaling_divisor, + int32 reverse_scaling_right_shift, int diff_min, + uint8* output_data, const Dims<4>& output_dims) { + LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier, + input_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Logistic(const int16* input_data, const Dims<4>& input_dims, + int16* output_data, const Dims<4>& output_dims) { + Logistic(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_zero_point, + input_range_radius, input_multiplier, input_left_shift, output_data, + DimsToShape(output_dims)); +} + +inline void Tanh(const int16* input_data, const Dims<4>& input_dims, + int input_left_shift, int16* output_data, + const Dims<4>& output_dims) { + Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data, + DimsToShape(output_dims)); +} + +} // namespace reference_ops +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 4781bbc70a30a4be1c932a59f60e229dc05fad17..89ec0eb266f09e3fe18a424dc90a1a3056d94a55 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -914,9 +914,9 @@ void GlobalBatchNormalization(const float* input_data, } } -inline void Relu(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float lower = 0; @@ -925,9 +925,10 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims, } } -inline void Relu1(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu1(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float upper = 1; @@ -937,9 +938,10 @@ inline void Relu1(const float* input_data, const Dims<4>& input_dims, } } -inline void Relu6(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input_dims, output_dims); +inline void Relu6(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; ++i) { const float val = input_data[i]; const float upper = 6; @@ -950,11 +952,14 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims, } template -void L2Normalization(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { +void L2Normalization(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { static_assert(Ac == FusedActivationFunctionType::kNone, ""); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { float squared_l2_norm = 0; for (int c = 0; c < depth; ++c) { @@ -968,8 +973,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, } } -inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, - int* output_shift) { +inline void GetInvSqrtQuantizedMultiplierExp(int32 input, + int32* output_inv_sqrt, + int* output_shift) { *output_shift = 11; while (input >= (1 << 29)) { input /= 4; @@ -1011,34 +1017,36 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, *output_inv_sqrt <<= -*output_shift; *output_shift = 0; } + *output_shift *= kReverseShift; } -inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, +inline void L2Normalization(const uint8* input_data, + const RuntimeShape& input_shape, int32 input_zero_point, uint8* output_data, - const Dims<4>& output_dims) { - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); + const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); for (int i = 0; i < outer_size; ++i) { int32 square_l2_norm = 0; for (int c = 0; c < depth; c++) { - int32 diff = - input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; + int32 diff = input_data[depth * i + c] - input_zero_point; square_l2_norm += diff * diff; } int32 inv_l2norm_multiplier; int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); + GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); for (int c = 0; c < depth; c++) { - int32 diff = - input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; + int32 diff = input_data[depth * i + c] - input_zero_point; int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( - 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift); + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); int32 unclamped_output_val = 128 + rescaled_diff; int32 output_val = std::min(255, std::max(0, unclamped_output_val)); - output_data[Offset(output_dims, c, i, 0, 0)] = - static_cast(output_val); + output_data[depth * i + c] = static_cast(output_val); } } } @@ -1128,22 +1136,12 @@ inline void Add(int left_shift, const uint8* input1_data, } } -template inline void Add(const int16* input1_data, const Dims<4>& input1_dims, int input1_shift, const int16* input2_data, const Dims<4>& input2_dims, int input2_shift, int16 output_activation_min, int16 output_activation_max, int16* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, -32768); - TFLITE_DCHECK_EQ(output_activation_max, 32767); - } const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); @@ -1169,6 +1167,28 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims, } } +template +inline void Add(const int16* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16* input2_data, + const Dims<4>& input2_dims, int input2_shift, + int16 output_activation_min, int16 output_activation_max, + int16* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, -32768); + TFLITE_DCHECK_EQ(output_activation_max, 32767); + } + + Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims, + input2_shift, output_activation_min, output_activation_max, output_data, + output_dims); +} + // TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -1749,7 +1769,6 @@ template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, Scalar* output_data, const Dims<4>& output_dims) { - TFLITE_DCHECK_GT(inputs_count, 1); int concat_size = 0; for (int i = 0; i < inputs_count; i++) { for (int j = 0; j < 4; j++) { @@ -1760,7 +1779,9 @@ void Concatenation(int concat_dim, const Scalar* const* input_data, concat_size += ArraySize(*input_dims[i], concat_dim); } TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + // For now we don't have a model with a Concatenation with fused activation. + TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone); int outer_size = 1; for (int i = concat_dim + 1; i < 4; i++) { outer_size *= output_dims.sizes[i]; @@ -2238,18 +2259,21 @@ inline int NodeOffset(int b, int h, int w, int height, int width) { return (b * height + h) * width + w; } -inline void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, +inline void AveragePool(const float* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float output_activation_min, float output_activation_max, float* output_data, - const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + const RuntimeShape& output_shape) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2273,12 +2297,12 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; total += - input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; filter_count++; } } const float average = total / filter_count; - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = ActivationFunctionWithMinMax(average, output_activation_min, output_activation_max); } @@ -2287,42 +2311,22 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, float* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, +inline void AveragePool(const uint8* input_data, + const RuntimeShape& input_shape, int stride_width, + int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2345,14 +2349,15 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, ++filter_x) { const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; - acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + acc += + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; filter_count++; } } acc = (acc + filter_count / 2) / filter_count; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = static_cast(acc); } } @@ -2360,50 +2365,19 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -inline void L2Pool(const float* input_data, const Dims<4>& input_dims, +inline void L2Pool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + float* output_data, const RuntimeShape& output_shape) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2427,13 +2401,13 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, const int in_x = in_x_origin + filter_x; const int in_y = in_y_origin + filter_y; const float val = - input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + input_data[Offset(input_shape, batch, in_y, in_x, channel)]; sum_squares += val * val; filter_count++; } } const float l2pool_result = std::sqrt(sum_squares / filter_count); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = ActivationFunctionWithMinMax(l2pool_result, output_activation_min, output_activation_max); } @@ -2442,40 +2416,19 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - - L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const float* input_data, const Dims<4>& input_dims, +inline void MaxPool(const float* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + float* output_data, const RuntimeShape& output_shape) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2499,10 +2452,10 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, const int in_y = in_y_origin + filter_y; max = std::max( max, - input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + input_data[Offset(input_shape, batch, in_y, in_x, channel)]); } } - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = ActivationFunctionWithMinMax(max, output_activation_min, output_activation_max); } @@ -2511,42 +2464,22 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, float* output_data, - const Dims<4>& output_dims) { - float output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - float* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_data, output_dims); -} - -inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, +inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { TFLITE_DCHECK_LE(output_activation_min, output_activation_max); TFLITE_DCHECK_GE(output_activation_min, 0); TFLITE_DCHECK_LE(output_activation_max, 255); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int batches = MatchingDim(input_shape, 0, output_shape, 0); + const int depth = MatchingDim(input_shape, 3, output_shape, 3); + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); for (int batch = 0; batch < batches; ++batch) { for (int out_y = 0; out_y < output_height; ++out_y) { for (int out_x = 0; out_x < output_width; ++out_x) { @@ -2570,12 +2503,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, const int in_y = in_y_origin + filter_y; max = std::max( max, - input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + input_data[Offset(input_shape, batch, in_y, in_x, channel)]); } } max = std::max(max, output_activation_min); max = std::min(max, output_activation_max); - output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + output_data[Offset(output_shape, batch, out_y, out_x, channel)] = static_cast(max); } } @@ -2583,38 +2516,6 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, } } -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, - int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { - static_assert(Ac == FusedActivationFunctionType::kNone || - Ac == FusedActivationFunctionType::kRelu || - Ac == FusedActivationFunctionType::kRelu6 || - Ac == FusedActivationFunctionType::kRelu1, - ""); - if (Ac == FusedActivationFunctionType::kNone) { - TFLITE_DCHECK_EQ(output_activation_min, 0); - TFLITE_DCHECK_EQ(output_activation_max, 255); - } - MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - -// legacy, for compatibility with old checked-in code -template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, - filter_width, filter_height, output_activation_min, - output_activation_max, output_data, output_dims); -} - inline void LocalResponseNormalization(const float* input_data, const Dims<4>& input_dims, int range, float bias, float alpha, float beta, @@ -2638,11 +2539,14 @@ inline void LocalResponseNormalization(const float* input_data, } } -inline void Softmax(const float* input_data, const Dims<4>& input_dims, +inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, - const Dims<4>& output_dims) { - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { // Find max element value which we'll use to ensure numerical stability @@ -2667,10 +2571,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, } } -inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, +inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_beta_multiplier, int32 input_beta_left_shift, int diff_min, uint8* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -2683,8 +2587,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { uint8 max_in_row = 0; @@ -2745,10 +2652,13 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); +inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { // Find max element value which we'll use to ensure numerical stability @@ -2888,11 +2798,11 @@ log_x_for_x_greater_than_or_equal_to_1( input_val); } -inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, +inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, int32 input_multiplier, int32 input_left_shift, int32 reverse_scaling_divisor, int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as // -32 before multiplying by input_beta_multiplier, and therefore as large as @@ -2906,8 +2816,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { uint8 max_in_row = 0; @@ -2971,9 +2884,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Logistic(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { float val = input_data[i]; @@ -2982,11 +2895,11 @@ inline void Logistic(const float* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, +inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); + uint8* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { const uint8 input_val_u8 = input_data[i]; @@ -3020,9 +2933,9 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Logistic(const int16* input_data, const Dims<4>& input_dims, - int16* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, + int16* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -3038,9 +2951,9 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Tanh(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { float val = input_data[i]; @@ -3049,12 +2962,12 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, +inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, int32 input_zero_point, int32 input_range_radius, int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { + uint8* output_data, const RuntimeShape& output_shape) { const int32 output_zero_point = 128; - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { const uint8 input_val_u8 = input_data[i]; @@ -3089,15 +3002,15 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, } } -inline void Tanh(const int16* input_data, const Dims<4>& input_dims, +inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, int input_left_shift, int16* output_data, - const Dims<4>& output_dims) { + const RuntimeShape& output_shape) { // Support for shifts is limited until we have a parameterized version of // SaturatingRoundingMultiplyByPOT(). TFLITE_DCHECK_GE(input_left_shift, 0); TFLITE_DCHECK_LE(input_left_shift, 1); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); // F0 uses 0 integer bits, range [-1, 1]. // This is the return type of math functions such as tanh, logistic, @@ -3202,9 +3115,10 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, } } -inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, +template +inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims, const int32* output_size_data, - const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_size_dims, T* output_data, const Dims<4>& output_dims, bool align_corners) { int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); @@ -3236,15 +3150,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 x0 = static_cast(std::floor(input_x)); int32 x1 = std::min(x0 + 1, input_width - 1); for (int c = 0; c < depth; ++c) { - float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] * - (1 - (input_y - y0)) * - (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x0, y1, b)] * - (input_y - y0) * (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x1, y0, b)] * - (1 - (input_y - y0)) * (input_x - x0) + - input_data[Offset(input_dims, c, x1, y1, b)] * - (input_y - y0) * (input_x - x0); + T interpolation = + static_cast(input_data[Offset(input_dims, c, x0, y0, b)] * + (1 - (input_y - y0)) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x0, y1, b)] * + (input_y - y0) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x1, y0, b)] * + (1 - (input_y - y0)) * (input_x - x0) + + input_data[Offset(input_dims, c, x1, y1, b)] * + (input_y - y0) * (input_x - x0)); output_data[Offset(output_dims, c, x, y, b)] = interpolation; } } @@ -3257,8 +3171,18 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, const Dims<4>& output_dims) { - ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, - output_data, output_dims, /*align_corners=*/false); + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); +} + +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); } template @@ -3506,8 +3430,6 @@ inline void Exp(const T* input_data, const size_t num_elements, } // A generic reduce method that can be used for reduce_sum, reduce_mean, etc. -// It takes a reducer function as input and returns false when numeric overflow -// is detected. // This method iterates through input data and reduce elements along the // dimensions given in axis. template @@ -3515,8 +3437,7 @@ inline bool Reduce(const In* input_data, const int* input_dims, const int* output_dims, const int input_num_dims, const int output_num_dims, const int* axis, const int num_axis, int* input_iter, - Out reducer(Out current, const In in, bool* overflow), - Out* output_data) { + Out reducer(Out current, const In in), Out* output_data) { // Reset input iterator. TFLITE_DCHECK(input_num_dims > 0); for (int idx = 0; idx < input_num_dims; ++idx) { @@ -3528,10 +3449,8 @@ inline bool Reduce(const In* input_data, const int* input_dims, ReducedOutputOffset(input_num_dims, input_dims, input_iter, 0, nullptr); size_t output_offset = ReducedOutputOffset(input_num_dims, input_dims, input_iter, num_axis, axis); - bool overflow = false; - output_data[output_offset] = reducer(output_data[output_offset], - input_data[input_offset], &overflow); - if (overflow) return false; + output_data[output_offset] = + reducer(output_data[output_offset], input_data[input_offset]); } while (NextIndex(input_num_dims, input_dims, input_iter)); return true; } @@ -3566,7 +3485,7 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims, const int output_num_dims, const int* axis, const int num_axis, int* input_iter, Out* output_data) { - auto reducer = [](Out current, const In in, bool* overflow) -> Out { + auto reducer = [](Out current, const In in) -> Out { const Out actual_in = static_cast(in); return current + actual_in; }; @@ -3575,6 +3494,39 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims, output_data); } +// Computes the sum of elements across dimensions given in axis. +template +inline bool Sum(const T* input_data, const int* input_dims, + const int input_num_dims, T* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis) { + // Reset output data. + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + size_t current = static_cast(output_dims[idx]); + // Overflow prevention. + if (num_outputs > std::numeric_limits::max() / current) { + return false; + } + num_outputs *= current; + } + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = T(); + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + return ReduceSumImpl(input_data, input_dims, output_dims, + input_num_dims, output_num_dims, resolved_axis, + num_resolved_axis, temp_index, output_data); +} + // Computes the mean of elements across dimensions given in axis. // It does so in two stages, first calculates the sum of elements along the axis // then divides it by the number of element in axis. @@ -3777,7 +3729,7 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, template void Transpose(const T* input, const Dims<4>& input_dims, T* output, - const Dims<4>& output_dims, int* permuted_axes) { + const Dims<4>& output_dims, const int* permuted_axes) { int out_sizes[4]; // Compute the inverse permutation array so we can do an output centered // transpose. Also, check to make sure output_dims is matching input_dims. @@ -3808,10 +3760,11 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, int stride_width, int stride_height, int pad_width, int pad_height, float* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, float* /*im2col_data*/, + const Dims<4>& /*im2col_dims*/) { const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); const int filter_height = ArraySize(filter_dims, 2); @@ -3826,7 +3779,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, // computing their influence on the output, rather than looping through the // output elements in the typical "gather" access pattern of a conv. We // therefore must initialize the output array to zero. - for (int i = 0; i < FlatSize(output_dims); i++) { + const int num_elements = FlatSize(output_dims); + for (int i = 0; i < num_elements; i++) { output_data[i] = 0.0f; } @@ -3851,8 +3805,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, float input_value = input_data[Offset(input_dims, in_channel, in_x, in_y, batch)]; float filter_value = - filter_data[Offset(filter_dims, out_channel, filter_x, - filter_y, in_channel)]; + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] += input_value * filter_value; } @@ -3865,6 +3819,16 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } } +template +inline bool EqualFn(T lhs, T rhs) { + return lhs == rhs; +} + +template +inline bool NotEqualFn(T lhs, T rhs) { + return lhs != rhs; +} + template inline bool GreaterFn(T lhs, T rhs) { return lhs > rhs; @@ -4028,6 +3992,8 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, input2_offset, input2_multiplier, \ input2_shift, output_data, output_dims); \ } +TFLITE_COMPARISON_OP(Equal); +TFLITE_COMPARISON_OP(NotEqual); TFLITE_COMPARISON_OP(Greater); TFLITE_COMPARISON_OP(GreaterEqual); TFLITE_COMPARISON_OP(Less); diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc similarity index 60% rename from tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc rename to tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc index c1c50dff4d2a966bff70853701334f599ee03849..3d8765f11b2941ef5871c7db8e3582e506713aa6 100644 --- a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc @@ -24,9 +24,10 @@ limitations under the License. namespace tflite { namespace { +template void TestOneResizeBilinear(int batch, int depth, int input_width, int input_height, int output_width, - int output_height) { + int output_height, float error_threshold) { Dims<4> input_dims_inference = MakeDimsForInference(depth, input_width, input_height, batch); Dims<4> output_dims_inference = @@ -36,14 +37,15 @@ void TestOneResizeBilinear(int batch, int depth, int input_width, const int output_buffer_size = RequiredBufferSizeForDims(output_dims_inference); - std::vector input_data(input_buffer_size, 0); - std::vector reference_output_data(output_buffer_size, 0); + std::vector input_data(input_buffer_size, 0); + std::vector reference_output_data(output_buffer_size, 0); // Initialize the output data with something other than zero, so we can catch // issue with kernels failing to initialize the output. - std::vector output_data(output_buffer_size, 3.1415); + std::vector output_data(output_buffer_size, 3); - const float input_amplitude = 1.f; - FillRandom(&input_data, -input_amplitude, input_amplitude); + const T min_amplitude = static_cast(0); + const T max_amplitude = static_cast(255); + FillRandom(&input_data, min_amplitude, max_amplitude); Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1); std::vector output_size_data = {output_height, output_width}; @@ -58,14 +60,46 @@ void TestOneResizeBilinear(int batch, int depth, int input_width, double sum_diff = 0; float max_abs_val = 0; for (int i = 0; i < output_buffer_size; i++) { - sum_diff += std::abs(output_data[i] - reference_output_data[i]); - max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i])); + sum_diff += std::abs(static_cast(output_data[i]) - + static_cast(reference_output_data[i])); + max_abs_val = std::max( + max_abs_val, std::abs(static_cast(reference_output_data[i]))); } if (sum_diff != 0.f) { const float mean_diff = static_cast(sum_diff / output_buffer_size); const float relative_error = std::abs(mean_diff) / max_abs_val; - ASSERT_LT(relative_error, 1e-5f); + ASSERT_LT(relative_error, error_threshold); + } +} + +TEST(ResizeBilinear, TestResizeBilinear8Bit) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 0.025); + } +} + +TEST(ResizeBilinear2x2, TestResizeBilinear8Bit) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = input_width * 2; + const int output_height = input_height * 2; + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); } } @@ -79,8 +113,8 @@ TEST(ResizeBilinear, TestResizeBilinear) { const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); - TestOneResizeBilinear(batch, depth, input_width, input_height, output_width, - output_height); + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); } } @@ -94,8 +128,8 @@ TEST(ResizeBilinear2x2, TestResizeBilinear) { const int output_width = input_width * 2; const int output_height = input_height * 2; - TestOneResizeBilinear(batch, depth, input_width, input_height, output_width, - output_height); + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); } } } // namespace diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc index d781a7b642036f3c5ddaa366f257fe26511c83c3..a7dad3c14e60fac9da9c0bcfd5d1d4c8f10b71c7 100644 --- a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc @@ -32,19 +32,21 @@ namespace tflite { namespace { void RunSoftmaxFloatReference(const uint8* input_data, - const Dims<4>& dims_common, int32 input_offset, - const double input_scale, int stride, float beta, + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, uint8* reference_output_data) { - const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); // Reference data generated via Dequant of input into float, and then applying // float Softmax. - reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, - reference_dequant_data.data(), dims_common); - optimized_ops::Softmax(reference_dequant_data.data(), dims_common, beta, - reference_output_float_data.data(), dims_common); + reference_ops::Dequantize( + input_data, ToRuntimeDims(shape_common), input_offset, input_scale, + reference_dequant_data.data(), ToRuntimeDims(shape_common)); + optimized_ops::Softmax(reference_dequant_data.data(), shape_common, beta, + reference_output_float_data.data(), shape_common); // Work with quantized scaling for Softmax, under which 256 represents 1, but // we limit this to 255. for (int i = 0; i < ref_buffer_size; i++) { @@ -55,9 +57,9 @@ void RunSoftmaxFloatReference(const uint8* input_data, } void CheckOutputData(const uint8* test_output, const uint8* reference_output, - const Dims<4>& dims_common, const string& check_label, - bool be_exacting) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + const string& check_label, bool be_exacting) { + const int buffer_size = shape_common.FlatSize(); // While calculating some metrics in floating point, we work with quantized // scaling. std::vector diff(buffer_size); @@ -91,15 +93,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output, // Runs the Softmax and compares against the float reference implementation and // the quantized reference implementation. -void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, - int32 input_offset, const double input_scale, int stride, - float beta) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); +void RunOneSoftmaxTest(const uint8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); std::vector optimized_softmax_output(buffer_size); std::vector reference_float_softmax_output(buffer_size); std::vector reference_quant_softmax_output(buffer_size); - RunSoftmaxFloatReference(input_data, dims_common, input_offset, input_scale, + RunSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_softmax_output.data()); int32 input_beta_multiplier; @@ -113,21 +115,21 @@ void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, input_beta_left_shift); - optimized_ops::Softmax(input_data, dims_common, input_beta_multiplier, + optimized_ops::Softmax(input_data, shape_common, input_beta_multiplier, input_beta_left_shift, diff_min, - optimized_softmax_output.data(), dims_common); - reference_ops::Softmax(input_data, dims_common, input_beta_multiplier, + optimized_softmax_output.data(), shape_common); + reference_ops::Softmax(input_data, shape_common, input_beta_multiplier, input_beta_left_shift, diff_min, - reference_quant_softmax_output.data(), dims_common); + reference_quant_softmax_output.data(), shape_common); CheckOutputData(optimized_softmax_output.data(), - reference_float_softmax_output.data(), dims_common, + reference_float_softmax_output.data(), shape_common, "Optimized vs float reference", false); CheckOutputData(optimized_softmax_output.data(), - reference_quant_softmax_output.data(), dims_common, + reference_quant_softmax_output.data(), shape_common, "Optimized vs quant reference", true); CheckOutputData(reference_quant_softmax_output.data(), - reference_float_softmax_output.data(), dims_common, + reference_float_softmax_output.data(), shape_common, "Quant reference vs float reference", false); } @@ -150,13 +152,13 @@ bool TryOneUniformSoftmax() { const int32 input_offset = UniformRandomInt(-256, 0); const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandom(&input_data); - RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale, + RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } @@ -188,14 +190,14 @@ bool TryOneSkyscraperSoftmax(bool small_depth) { const int middle_min = UniformRandomInt(0, 255); const int sides_max = UniformRandomInt(0, middle_min); - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); - RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale, + RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index ce887cea8b794b4b0cfd31722581cf9327be625e..518bee1c6369d3ce93d1b98e19dba7615b5844dc 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -34,6 +34,11 @@ inline uint8_t* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.uint8 : nullptr; } +template <> +inline int16_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + template <> inline int32_t* GetTensorData(TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.i32 : nullptr; @@ -62,6 +67,11 @@ inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.uint8 : nullptr; } +template <> +inline const int16_t* GetTensorData(const TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i16 : nullptr; +} + template <> inline const int32_t* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.i32 : nullptr; @@ -114,6 +124,19 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { return GetTensorDims(dims->data, dims->size); } +inline RuntimeShape GetTensorShape(std::vector data) { + return RuntimeShape(data.size(), data.data()); +} + +inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return RuntimeShape(); + } + + auto* dims = tensor->dims; + return RuntimeShape(dims->size, dims->data); +} + // A list of tensors in a format that can be used by kernels like split and // concatenation. template diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 0c7fb7a76a5075652e705e65f5379596dfa77c78..707d2d261a4c4b7066d7dd34a26132a4d2f1722f 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -65,6 +65,10 @@ class RuntimeShape { ReplaceWith(dimensions_count, dims_data); } + RuntimeShape(const std::initializer_list init_list) : size_(0) { + BuildFrom(init_list); + } + ~RuntimeShape() { if (size_ > kMaxSmallSize) { delete[] dims_pointer_; @@ -121,6 +125,10 @@ class RuntimeShape { } } + inline void BuildFrom(const std::initializer_list init_list) { + BuildFrom>(init_list); + } + // Returns the total count of elements, that is the size when flattened into a // vector. inline int FlatSize() const { @@ -142,6 +150,22 @@ class RuntimeShape { }; }; +// Converts inference-style shape to legacy tflite::Dims<4>. +inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) { + tflite::Dims<4> result; + const int dimensions_count = array_shape.DimensionsCount(); + TFLITE_CHECK_LE(dimensions_count, 4); + int cum_prod = 1; + for (int i = 0; i < 4; i++) { + const int new_dim = + (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1; + result.sizes[i] = new_dim; + result.strides[i] = cum_prod; + cum_prod *= new_dim; + } + return result; +} + // Gets next index to iterate through a multidimensional array. inline bool NextIndex(const int num_dims, const int* dims, int* current) { TFLITE_DCHECK_GT(num_dims, 0); @@ -194,6 +218,15 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims, return offset; } +inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < shape.Dims(0)); + TFLITE_DCHECK(i1 >= 0 && i1 < shape.Dims(1)); + TFLITE_DCHECK(i2 >= 0 && i2 < shape.Dims(2)); + TFLITE_DCHECK(i3 >= 0 && i3 < shape.Dims(3)); + const int* dims_data = shape.DimsData(); + return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3; +} + inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]); TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]); @@ -208,6 +241,9 @@ inline int Offset(const Dims<4>& dims, int* index) { } // Get array size, DCHECKing that the dim index is in range. +// +// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims() +// already performs this check. template int ArraySize(const Dims& array, int index) { TFLITE_DCHECK(index >= 0 && index < N); @@ -229,6 +265,21 @@ int MatchingArraySize(const ArrayType1& array1, int index1, return MatchingArraySize(array1, index1, args...); } +// Get common shape dim, DCHECKing that they all agree. +inline int MatchingDim(const RuntimeShape& shape1, int index1, + const RuntimeShape& shape2, int index2) { + TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2)); + return shape1.Dims(index1); +} + +template +int MatchingDim(const RuntimeShape& shape1, int index1, + const RuntimeShape& shape2, int index2, Args... args) { + TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2)); + return MatchingDim(shape1, index1, args...); +} + +// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize(). template inline int FlatSize(const Dims& dims) { int flat_size = 1; @@ -243,6 +294,50 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) { return FlatSize(dims); } +// Flat size calculation, checking that dimensions match with one or more other +// arrays. +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return shape.FlatSize(); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1, check_shape_2); +} + +inline int MatchingFlatSize(const RuntimeShape& shape, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2, + const RuntimeShape& check_shape_3) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3); +} + // Flat size calculation, checking that dimensions match with one or more other // arrays. template @@ -269,7 +364,7 @@ inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } - return FlatSize(dims, check_dims_1, check_dims_2); + return MatchingFlatSize(dims, check_dims_1, check_dims_2); } template @@ -280,7 +375,7 @@ inline int MatchingFlatSize(const Dims& dims, const Dims& check_dims_0, for (int i = 0; i < N; ++i) { TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i)); } - return FlatSize(dims, check_dims_1, check_dims_2, check_dims_3); + return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3); } // Data is required to be contiguous, and so many operators can use either the @@ -348,6 +443,72 @@ inline int MatchingFlatSizeSkipDim(const Dims& dims, int skip_dim, check_dims_3); } +// Data is required to be contiguous, and so many operators can use either the +// full array flat size or the flat size with one dimension skipped (commonly +// the depth). +inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) { + const int dims_count = shape.DimensionsCount(); + TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count); + const auto* dims_data = shape.DimsData(); + int flat_size = 1; + for (int i = 0; i < dims_count; ++i) { + flat_size *= (i == skip_dim) ? 1 : dims_data[i]; + } + return flat_size; +} + +// A combination of MatchingFlatSize() and FlatSizeSkipDim(). +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return FlatSizeSkipDim(shape, skip_dim); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2); +} + +inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim, + const RuntimeShape& check_shape_0, + const RuntimeShape& check_shape_1, + const RuntimeShape& check_shape_2, + const RuntimeShape& check_shape_3) { + const int dims_count = shape.DimensionsCount(); + for (int i = 0; i < dims_count; ++i) { + if (i != skip_dim) { + TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i)); + } + } + return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2, + check_shape_3); +} + template bool IsPackedWithoutStrides(const Dims& dims) { int expected_stride = 1; diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc index 184028427fb193aa99cf155961c16eda1298e326..fdf9856912b9a0f4b6acf81db6ecb5f9c9385f0b 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.cc +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -43,12 +43,11 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, return kTfLiteOk; } -void CalculateActivationRangeUint8(TfLiteFusedActivation activation, - TfLiteTensor* output, int32_t* act_min, - int32_t* act_max) { - const int32_t qmin = std::numeric_limits::min(); - const int32_t qmax = std::numeric_limits::max(); - +namespace { +void CalculateActivationRangeQuantizedImpl(TfLiteFusedActivation activation, + int32_t qmin, int32_t qmax, + TfLiteTensor* output, + int32_t* act_min, int32_t* act_max) { const auto scale = output->params.scale; const auto zero_point = output->params.zero_point; @@ -70,6 +69,39 @@ void CalculateActivationRangeUint8(TfLiteFusedActivation activation, *act_max = qmax; } } +} // namespace + +TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context, + TfLiteFusedActivation activation, + TfLiteTensor* output, + int32_t* act_min, + int32_t* act_max) { + int32_t qmin = 0; + int32_t qmax = 0; + if (output->type == kTfLiteUInt8) { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); + } else if (output->type == kTfLiteInt16) { + qmin = std::numeric_limits::min(); + qmax = std::numeric_limits::max(); + } else { + TF_LITE_ENSURE(context, false); + } + + CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min, + act_max); + return kTfLiteOk; +} + +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + + CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min, + act_max); +} void CalculateActivationRangeFloat(TfLiteFusedActivation activation, float* activation_min, diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index 82cded36f2ed2777daccafee5890f47c0d7254e8..20058a5f6971ffc6c1763b0e98cd4a91fe7b6e44 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -88,6 +88,11 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, // Calculates the useful range of an activation layer given its activation // tensor. +TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context, + TfLiteFusedActivation activation, + TfLiteTensor* output, + int32_t* act_min, + int32_t* act_max); void CalculateActivationRangeUint8(TfLiteFusedActivation activation, TfLiteTensor* output, int32_t* act_min, int32_t* act_max); diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc index 3205c1cc52724207904621a5870636841ef379fe..a7b54c6b842332feb2d9e7179e79ae054bd23bb9 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -70,8 +70,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { #define TF_LITE_L2NORM(type) \ type::L2Normalization( \ - GetTensorData(input), GetTensorDims(input), \ - GetTensorData(output), GetTensorDims(output)) + GetTensorData(input), GetTensorShape(input), \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2NORM(reference_ops); @@ -81,10 +81,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } #undef TF_LITE_L2NORM } else if (output->type == kTfLiteUInt8) { -#define TF_LITE_L2NORM(type) \ - type::L2Normalization(GetTensorData(input), GetTensorDims(input), \ - input->params.zero_point, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_L2NORM(type) \ + type::L2Normalization(GetTensorData(input), GetTensorShape(input), \ + input->params.zero_point, \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2NORM(reference_ops); diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc index 62820a2f5113cb6ae252386aaf3842135383b79f..9a8d35e82cbc3a7e55246e6c06599b2838d1ee67 100644 --- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc @@ -90,10 +90,9 @@ TEST(LogSoftmaxOpTest, CompareWithTFmini) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::LogSoftmax(input_buffer, input_dims, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::LogSoftmax(input_buffer, input_shape, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 9aae3e571b33754703b353545e418d3485b7433c..1dda97c101bb12692690f4d3041d1176b8c54392 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -37,14 +37,17 @@ namespace builtin { namespace lstm { struct OpData { - // Which kernel type to use. Full kernel (18-inputs) or basic kernel - // (5-inputs). + // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel + // (5 inputs). TfLiteLSTMKernelType kernel_type; - // Only used by full kernel. + + // These fields are only used by full kernel. + int activation_state_tensor_index; + int cell_state_tensor_index; int scratch_tensor_index; }; -// For full inputs kernel (18-inputs). +// For full inputs kernel (18 or 20 inputs). namespace full { // Input Tensors of size {n_batch, n_input} @@ -78,7 +81,16 @@ constexpr int kProjectionWeightsTensor = 16; // Optional // Projection bias tensor of size {n_output} constexpr int kProjectionBiasTensor = 17; // Optional +// If the node has 20 inputs, the following 2 tensors are used as state tensors. +// These are defined as variable tensors, and will be modified by this op. +constexpr int kInputActivationStateTensor = 18; +constexpr int kInputCellStateTensor = 19; + // Output tensors. +// * If the node has 18 inputs, these 2 tensors are used as state tensors. +// * If the node has 20 inputs, these 2 tensors are ignored. +// TODO(ycling): Make the 2 output state tensors optional, and propagate the +// state to output tensors when the 2 tensors present. constexpr int kOutputStateTensor = 0; constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; @@ -86,7 +98,8 @@ constexpr int kOutputTensor = 2; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData; op_data->kernel_type = kTfLiteLSTMFullKernel; - context->AddTensors(context, 1, &op_data->scratch_tensor_index); + context->AddTensors(context, /*tensors_to_add=*/7, + &op_data->scratch_tensor_index); return op_data; } @@ -94,7 +107,7 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -104,7 +117,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - if (input_to_input_weights) { + if (input_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); @@ -124,7 +137,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - if (recurrent_to_input_weights) { + if (recurrent_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], n_cell); @@ -214,7 +227,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* projection_weights = GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - if (projection_weights) { + if (projection_weights != nullptr) { TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); @@ -222,7 +235,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* projection_bias = GetOptionalInputTensor(context, node, kProjectionBiasTensor); - if (projection_bias) { + if (projection_bias != nullptr) { TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); } @@ -245,13 +258,35 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpData* op_data = reinterpret_cast(node->user_data); - // Check we have all the inputs and outputs we need. - TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); TF_LITE_ENSURE_EQ(context, node->outputs->size, 3); + // True if the node is using input variable state tensors. It means: + // * The state tensors are defined as inputs. In this case it would be the + // 19th and 20th input tensors. + // * Otherwise, the output tensors are used to store states. + bool use_input_variable_states; + if (node->inputs->size == 20) { + use_input_variable_states = true; + op_data->activation_state_tensor_index = + node->inputs->data[kInputActivationStateTensor]; + op_data->cell_state_tensor_index = + node->inputs->data[kInputCellStateTensor]; + } else if (node->inputs->size == 18) { + use_input_variable_states = false; + op_data->activation_state_tensor_index = + node->outputs->data[kOutputStateTensor]; + op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor]; + } else { + context->ReportError( + context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs", + node->inputs->size); + return kTfLiteError; + } + // Inferring batch size, number of outputs and number of cells from the // input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE(context, input->dims->size > 1); const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; @@ -272,110 +307,185 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check that input tensor dimensions matches with each other. CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); - // Get the pointer to output, output_state and cell_state tensors. + // Get the pointer to output, activation_state and cell_state tensors. TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - // Resize the output, output_state and cell_state tensors. + TfLiteTensor* activation_state = + &context->tensors[op_data->activation_state_tensor_index]; + TfLiteTensor* cell_state = + &context->tensors[op_data->cell_state_tensor_index]; + + if (use_input_variable_states) { + // Check the shape of input state tensors. + // These tensor may be 1D or 2D. It's fine as long as the total size is + // correct. + TF_LITE_ENSURE_EQ(context, NumElements(activation_state), + n_batch * n_output); + TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); + } else { + // If the state tensors are outputs, this function takes the + // responsibility to resize the state tensors. + TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2); + activation_state_size->data[0] = n_batch; + activation_state_size->data[1] = n_output; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state, + activation_state_size)); + + TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); + cell_size->data[0] = n_batch; + cell_size->data[1] = n_cell; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state, cell_size)); + // Mark state tensors as persistent tensors. + activation_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + } + + // Resize the output tensors. TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); output_size->data[0] = n_batch; output_size->data[1] = n_output; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, output_size)); - TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2); - output_state_size->data[0] = n_batch; - output_state_size->data[1] = n_output; - TF_LITE_ENSURE_OK( - context, context->ResizeTensor(context, output_state, output_state_size)); - - TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); - cell_size->data[0] = n_batch; - cell_size->data[1] = n_cell; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, cell_state, cell_size)); + // The weights are of consistent type, so it suffices to check one. + // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. + const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && + input->type == kTfLiteFloat32); - // Create a scratch buffer tensor. TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(7); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } node->temporaries->data[0] = op_data->scratch_tensor_index; + + // Create a scratch buffer tensor. TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; - // Mark state tensors as persistent tensors. - output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); const bool use_cifg = (input_to_input_weights == nullptr); + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; if (use_cifg) { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 3; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); } else { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Input, Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 4; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // activation_state and cell_state tensors. + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; + TfLiteTensor* activation_state_quantized = + GetTemporary(context, node, /*index=*/2); + activation_state_quantized->type = kTfLiteUInt8; + activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(activation_state_quantized->dims, + activation_state->dims)) { + TfLiteIntArray* activation_state_quantized_size = + TfLiteIntArrayCopy(activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, activation_state_quantized, + activation_state_quantized_size)); + } + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + cell_state_quantized->type = kTfLiteUInt8; + cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { + TfLiteIntArray* cell_state_quantized_size = + TfLiteIntArrayCopy(cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state_quantized, + cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). + node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[6] = op_data->scratch_tensor_index + 6; + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + recovered_cell_weights->type = kTfLiteFloat32; + recovered_cell_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); + recovered_cell_weights_size->data[0] = n_cell; + if (!TfLiteIntArrayEqual(recovered_cell_weights->dims, + recovered_cell_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_cell_weights, + recovered_cell_weights_size)); + } } return kTfLiteOk; } // The LSTM Op engine. -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - - const TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - const TfLiteTensor* input_to_forget_weights = - GetInput(context, node, kInputToForgetWeightsTensor); - const TfLiteTensor* input_to_cell_weights = - GetInput(context, node, kInputToCellWeightsTensor); - const TfLiteTensor* input_to_output_weights = - GetInput(context, node, kInputToOutputWeightsTensor); - - const TfLiteTensor* recurrent_to_input_weights = - GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - const TfLiteTensor* recurrent_to_forget_weights = - GetInput(context, node, kRecurrentToForgetWeightsTensor); - const TfLiteTensor* recurrent_to_cell_weights = - GetInput(context, node, kRecurrentToCellWeightsTensor); - const TfLiteTensor* recurrent_to_output_weights = - GetInput(context, node, kRecurrentToOutputWeightsTensor); - - const TfLiteTensor* cell_to_input_weights = - GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - const TfLiteTensor* cell_to_forget_weights = - GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - const TfLiteTensor* cell_to_output_weights = - GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - - const TfLiteTensor* input_gate_bias = - GetOptionalInputTensor(context, node, kInputGateBiasTensor); - const TfLiteTensor* forget_gate_bias = - GetInput(context, node, kForgetGateBiasTensor); - const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); - const TfLiteTensor* output_gate_bias = - GetInput(context, node, kOutputGateBiasTensor); - - const TfLiteTensor* projection_weights = - GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - const TfLiteTensor* projection_bias = - GetOptionalInputTensor(context, node, kProjectionBiasTensor); - - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* activation_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; // n_cell and n_output will be the same size when there is no projection. @@ -387,9 +497,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); - // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); - float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; float* forget_gate_scratch = nullptr; @@ -438,7 +545,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const float* cell_bias_ptr = cell_bias->data.f; const float* output_gate_bias_ptr = output_gate_bias->data.f; - float* output_state_ptr = output_state->data.f; + float* activation_state_ptr = activation_state->data.f; float* cell_state_ptr = cell_state->data.f; float* output_ptr_batch = output->data.f; @@ -451,9 +558,267 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, output_ptr_batch); + activation_state_ptr, cell_state_ptr, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); + + return kTfLiteOk; +} +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, + TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, + TfLiteTensor* activation_state_quantized, + TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + reinterpret_cast(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast(projection_weights->data.uint8); + const float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const float* input_ptr_batch = input->data.f; + const int8_t* input_to_forget_weights_ptr = + reinterpret_cast(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* activation_state_ptr = activation_state->data.f; + float* cell_state_ptr = cell_state->data.f; + float* output_ptr_batch = output->data.f; + + // Temporary storage for quantized values and scaling factors. + int8_t* quantized_input_ptr = + reinterpret_cast(input_quantized->data.uint8); + int8_t* quantized_activation_state_ptr = + reinterpret_cast(activation_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, + projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, + input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, + recovered_cell_weights_ptr, quantized_input_ptr, + quantized_activation_state_ptr, quantized_cell_state_ptr, + activation_state_ptr, cell_state_ptr, output_ptr_batch); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* params = reinterpret_cast(node->builtin_data); + OpData* op_data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* activation_state = + &context->tensors[op_data->activation_state_tensor_index]; + TfLiteTensor* cell_state = + &context->tensors[op_data->cell_state_tensor_index]; + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(mirkov): add a check that weights are all uint8s or all floats. + switch (input_to_output_weights->type) { + case kTfLiteFloat32: { + return EvalFloat(input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, + scratch_buffer, activation_state, cell_state, output); + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* activation_state_quantized = + GetTemporary(context, node, /*index=*/2); + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + return EvalHybrid( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, scratch_buffer, + scaling_factors, prod_scaling_factors, recovered_cell_weights, + input_quantized, activation_state_quantized, cell_state_quantized, + activation_state, cell_state, output); + } + default: + context->ReportError(context, "Type %d is not currently supported.", + input_to_output_weights->type); + return kTfLiteError; + } return kTfLiteOk; } @@ -491,7 +856,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, node->inputs->size == kInputNum); TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); - // Only Float32 is supportted currently. + // Only Float32 is supported currently. // TODO(ycling): Implement quantize uint8 support. for (int index = 0; index < node->inputs->size; ++index) { TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index d81220d8d30793616444c03e8647b0877a39a4d9..3f5c44a63ec328b23e11dc42428f7cd85a788509 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite LSTM op. -#include #include #include @@ -35,7 +34,8 @@ class LSTMOpModel : public SingleOpModel { LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, - const std::vector>& input_shapes) + const std::vector>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -45,31 +45,31 @@ class LSTMOpModel : public SingleOpModel { if (use_cifg) { input_to_input_weights_ = AddNullInput(); } else { - input_to_input_weights_ = AddInput(TensorType_FLOAT32); + input_to_input_weights_ = AddInput(weight_type); } - input_to_forget_weights_ = AddInput(TensorType_FLOAT32); - input_to_cell_weights_ = AddInput(TensorType_FLOAT32); - input_to_output_weights_ = AddInput(TensorType_FLOAT32); + input_to_forget_weights_ = AddInput(weight_type); + input_to_cell_weights_ = AddInput(weight_type); + input_to_output_weights_ = AddInput(weight_type); if (use_cifg) { recurrent_to_input_weights_ = AddNullInput(); } else { - recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_input_weights_ = AddInput(weight_type); } - recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_forget_weights_ = AddInput(weight_type); + recurrent_to_cell_weights_ = AddInput(weight_type); + recurrent_to_output_weights_ = AddInput(weight_type); if (use_peephole) { if (use_cifg) { cell_to_input_weights_ = AddNullInput(); } else { - cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + cell_to_input_weights_ = AddInput(weight_type); } - cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); - cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); } else { cell_to_input_weights_ = AddNullInput(); cell_to_forget_weights_ = AddNullInput(); @@ -86,7 +86,7 @@ class LSTMOpModel : public SingleOpModel { output_gate_bias_ = AddInput(TensorType_FLOAT32); if (use_projection_weights) { - projection_weights_ = AddInput(TensorType_FLOAT32); + projection_weights_ = AddInput(weight_type); if (use_projection_bias) { projection_bias_ = AddInput(TensorType_FLOAT32); } else { @@ -97,6 +97,12 @@ class LSTMOpModel : public SingleOpModel { projection_bias_ = AddNullInput(); } + // Adding the 2 input state tensors. + input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + output_state_ = AddOutput(TensorType_FLOAT32); cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -192,8 +198,9 @@ class LSTMOpModel : public SingleOpModel { zero_buffer.get() + zero_buffer_size); } - void SetInput(int offset, float* begin, float* end) { - PopulateTensor(input_, offset, begin, end); + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); } std::vector GetOutput() { return ExtractVector(output_); } @@ -203,7 +210,7 @@ class LSTMOpModel : public SingleOpModel { int num_cells() { return n_cell_; } int num_batches() { return n_batch_; } - private: + protected: int input_; int input_to_input_weights_; int input_to_forget_weights_; @@ -226,6 +233,8 @@ class LSTMOpModel : public SingleOpModel { int projection_weights_; int projection_bias_; + int input_activation_state_; + int input_cell_state_; int output_; int output_state_; @@ -237,7 +246,182 @@ class LSTMOpModel : public SingleOpModel { int n_output_; }; -TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { +class HybridLSTMOpModel : public LSTMOpModel { + public: + HybridLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, + float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole, + use_projection_weights, use_projection_bias, cell_clip, + proj_clip, input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + LSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end); + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + for (int i = 0; i < num_outputs; ++i) { + std::cout << lstm->GetOutput()[i] << ", "; + } + std::cout << std::endl; + for (int i = 0; i < num_outputs; ++i) { + std::cout << expected[i] << ", "; + } + std::cout << std::endl; + } + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -257,10 +441,10 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {n_cell, n_input}, // input_to_cell_weight tensor {n_cell, n_input}, // input_to_output_weight tensor - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor + {n_cell, n_output}, // recurrent_to_input_weight_tensor + {n_cell, n_output}, // recurrent_to_forget_weight_tensor + {n_cell, n_output}, // recurrent_to_cell_weight_tensor + {n_cell, n_output}, // recurrent_to_output_weight_tensor {0}, // cell_to_input_weight tensor {0}, // cell_to_forget_weight tensor @@ -275,79 +459,137 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, - -0.34550029, 0.04266912, -0.15680569, - -0.34856534, 0.43890524}); - - lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, - -0.20583314, 0.44344562, 0.22077113, - -0.29909778}); - - lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, - -0.31343272, -0.40032279, 0.44781327, - 0.01387155, -0.35593212}); - - lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, - 0.40525138, 0.44272184, 0.03897077, -0.1556896, - 0.19487578}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetInputGateBias({0., 0., 0., 0.}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetCellBias({0., 0., 0., 0.}); + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - lstm.SetForgetGateBias({1., 1., 1., 1.}); - - lstm.SetOutputGateBias({0., 0., 0., 0.}); - - lstm.SetRecurrentToInputWeights( - {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, - -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, - -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); - - lstm.SetRecurrentToCellWeights( - {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, - -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, - -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - lstm.SetRecurrentToForgetWeights( - {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, - -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, - 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetRecurrentToOutputWeights( - {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, - 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, - -0.51818722, -0.15390486, 0.0468148, 0.39922136}); +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, - -0.15358765, -0.03716109, 0.12507336, - 0.41193449, -0.20860538, -0.15053082, - 0.09120187, 0.24278517, -0.12222792}; + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, + /*tolerance=*/0.0157651); +} - lstm.SetInput(0, batch0_start, batch0_end); +class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; - lstm.Invoke(); + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; } -} +}; -TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -385,74 +627,689 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, - 0.04717243, 0.48944736, -0.38535351, - -0.17212132}); - - lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, - -0.3633365, -0.22755712, 0.28253698, 0.24407166, - 0.33826375}); - - lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, - -0.09426838, -0.44257352, 0.54939759, - 0.01533556, 0.42751634}); - - lstm.SetCellBias({0., 0., 0., 0.}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetForgetGateBias({1., 1., 1., 1.}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetOutputGateBias({0., 0., 0., 0.}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - lstm.SetRecurrentToCellWeights( - {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, - 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, - 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, - 0.21193194}); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); - lstm.SetRecurrentToForgetWeights( - {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, - 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, - -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - lstm.SetRecurrentToOutputWeights( - {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, - -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, - 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetCellToForgetWeights( - {0.47485286, -0.51955009, -0.24458408, 0.31544167}); - lstm.SetCellToOutputWeights( - {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, - -0.05163646, -0.42312205, -0.01218222, - 0.24201041, -0.08124574, -0.358325, - -0.04621704, 0.21641694, -0.06471302}; + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); - - lstm.SetInput(0, batch0_start, batch0_end); - - lstm.Invoke(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); +} - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, -0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; } -} +}; -TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -489,588 +1346,98 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights( - {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, - 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, - -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, - -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, - -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, - -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, - -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, - 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, - 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, - 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, - -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, - 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, - -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, - -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, - -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, - 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, - -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, - -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, - -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, - -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); - - lstm.SetInputToForgetWeights( - {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, - -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, - -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, - 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, - 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, - -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, - -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, - 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, - 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, - 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, - 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, - -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, - 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, - -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, - -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, - 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, - 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, - 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, - -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, - 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); - - lstm.SetInputToCellWeights( - {-0.04580283, -0.09549462, -0.032418985, -0.06454633, - -0.043528453, 0.043018587, -0.049152344, -0.12418144, - -0.078985475, -0.07596889, 0.019484362, -0.11434962, - -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, - -0.025034338, -0.0028890965, 0.048929527, 0.06235075, - 0.10665918, -0.032036792, -0.08505916, -0.10843358, - -0.13002433, -0.036816437, -0.02130134, -0.016518239, - 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, - -0.10652836, -0.1037554, -0.13056071, -0.03266643, - -0.033702414, -0.006473424, -0.04611692, 0.014419339, - -0.025174323, 0.0396852, 0.081777506, 0.06157468, - 0.10210095, -0.009658194, 0.046511717, 0.03603906, - 0.0069369148, 0.015960095, -0.06507666, 0.09551598, - 0.053568836, 0.06408714, 0.12835667, -0.008714329, - -0.20211966, -0.12093674, 0.029450472, 0.2849013, - -0.029227901, 0.1164364, -0.08560263, 0.09941786, - -0.036999565, -0.028842626, -0.0033637602, -0.017012902, - -0.09720865, -0.11193351, -0.029155117, -0.017936034, - -0.009768936, -0.04223324, -0.036159635, 0.06505112, - -0.021742892, -0.023377212, -0.07221364, -0.06430552, - 0.05453865, 0.091149814, 0.06387331, 0.007518393, - 0.055960953, 0.069779344, 0.046411168, 0.10509911, - 0.07463894, 0.0075130584, 0.012850982, 0.04555431, - 0.056955688, 0.06555285, 0.050801456, -0.009862683, - 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); - - lstm.SetInputToOutputWeights( - {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, - -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, - 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, - -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, - -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, - 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, - -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, - -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, - -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, - -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, - 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, - 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, - 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, - -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, - 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, - 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, - -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, - 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, - -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, - -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); - - lstm.SetInputGateBias( - {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, - -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, - -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, - 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); - - lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, - 0.11098921, 0.15378423, 0.09263801, 0.09790885, - 0.09508917, 0.061199076, 0.07665568, -0.015443159, - -0.03499149, 0.046190713, 0.08895977, 0.10899629, - 0.40694186, 0.06030037, 0.012413437, -0.06108739}); - - lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, - -0.1483596, -0.10639995, -0.091433935, 0.058573797, - -0.06809782, -0.07889636, -0.043246906, -0.09829136, - -0.4279842, 0.034901652, 0.18797937, 0.0075234566, - 0.016178843, 0.1749513, 0.13975595, 0.92058027}); - - lstm.SetOutputGateBias( - {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, - 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, - 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, - -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); - - lstm.SetRecurrentToInputWeights( - {-0.001374326, -0.078856036, 0.10672688, 0.029162422, - -0.11585556, 0.02557986, -0.13446963, -0.035785314, - -0.01244275, 0.025961924, -0.02337298, -0.044228926, - -0.055839065, -0.046598054, -0.010546039, -0.06900766, - 0.027239809, 0.022582639, -0.013296484, -0.05459212, - 0.08981, -0.045407712, 0.08682226, -0.06867011, - -0.14390695, -0.02916037, 0.000996957, 0.091420636, - 0.14283475, -0.07390571, -0.06402044, 0.062524505, - -0.093129106, 0.04860203, -0.08364217, -0.08119002, - 0.009352075, 0.22920375, 0.0016303885, 0.11583097, - -0.13732095, 0.012405723, -0.07551853, 0.06343048, - 0.12162708, -0.031923793, -0.014335606, 0.01790974, - -0.10650317, -0.0724401, 0.08554849, -0.05727212, - 0.06556731, -0.042729504, -0.043227166, 0.011683251, - -0.013082158, -0.029302018, -0.010899579, -0.062036745, - -0.022509435, -0.00964907, -0.01567329, 0.04260106, - -0.07787477, -0.11576462, 0.017356863, 0.048673786, - -0.017577527, -0.05527947, -0.082487635, -0.040137455, - -0.10820036, -0.04666372, 0.022746278, -0.07851417, - 0.01068115, 0.032956902, 0.022433773, 0.0026891115, - 0.08944216, -0.0685835, 0.010513544, 0.07228705, - 0.02032331, -0.059686817, -0.0005566496, -0.086984694, - 0.040414046, -0.1380399, 0.094208956, -0.05722982, - 0.012092817, -0.04989123, -0.086576, -0.003399834, - -0.04696032, -0.045747425, 0.10091314, 0.048676282, - -0.029037097, 0.031399418, -0.0040285117, 0.047237843, - 0.09504992, 0.041799378, -0.049185462, -0.031518843, - -0.10516937, 0.026374253, 0.10058866, -0.0033195973, - -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, - -0.10167381, 0.042500053, -0.01447153, 0.06464186, - -0.017142897, 0.03312627, 0.009205989, 0.024138335, - -0.011337001, 0.035530265, -0.010912711, 0.0706555, - -0.005894094, 0.051841937, -0.1401738, -0.02351249, - 0.0365468, 0.07590991, 0.08838724, 0.021681072, - -0.10086113, 0.019608743, -0.06195883, 0.077335775, - 0.023646897, -0.095322326, 0.02233014, 0.09756986, - -0.048691444, -0.009579111, 0.07595467, 0.11480546, - -0.09801813, 0.019894179, 0.08502348, 0.004032281, - 0.037211012, 0.068537936, -0.048005626, -0.091520436, - -0.028379958, -0.01556313, 0.06554592, -0.045599163, - -0.01672207, -0.020169014, -0.011877351, -0.20212261, - 0.010889619, 0.0047078193, 0.038385306, 0.08540671, - -0.017140968, -0.0035865551, 0.016678626, 0.005633034, - 0.015963363, 0.00871737, 0.060130805, 0.028611384, - 0.10109069, -0.015060172, -0.07894427, 0.06401885, - 0.011584063, -0.024466386, 0.0047652307, -0.09041358, - 0.030737216, -0.0046374933, 0.14215417, -0.11823516, - 0.019899689, 0.006106124, -0.027092824, 0.0786356, - 0.05052217, -0.058925, -0.011402121, -0.024987547, - -0.0013661642, -0.06832946, -0.015667673, -0.1083353, - -0.00096863037, -0.06988685, -0.053350925, -0.027275559, - -0.033664223, -0.07978348, -0.025200296, -0.017207067, - -0.058403496, -0.055697463, 0.005798788, 0.12965427, - -0.062582195, 0.0013350133, -0.10482091, 0.0379771, - 0.072521195, -0.0029455067, -0.13797039, -0.03628521, - 0.013806405, -0.017858358, -0.01008298, -0.07700066, - -0.017081132, 0.019358726, 0.0027079724, 0.004635139, - 0.062634714, -0.02338735, -0.039547626, -0.02050681, - 0.03385117, -0.083611414, 0.002862572, -0.09421313, - 0.058618143, -0.08598433, 0.00972939, 0.023867095, - -0.053934585, -0.023203006, 0.07452513, -0.048767887, - -0.07314807, -0.056307215, -0.10433547, -0.06440842, - 0.04328182, 0.04389765, -0.020006588, -0.09076438, - -0.11652589, -0.021705797, 0.03345259, -0.010329105, - -0.025767034, 0.013057034, -0.07316461, -0.10145612, - 0.06358255, 0.18531723, 0.07759293, 0.12006465, - 0.1305557, 0.058638252, -0.03393652, 0.09622831, - -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, - -0.005644518, 0.06857898, -0.12598175, -0.035084512, - 0.03156317, -0.12794146, -0.031963028, 0.04692781, - 0.030070418, 0.0071660685, -0.095516115, -0.004643372, - 0.040170413, -0.062104587, -0.0037324072, 0.0554317, - 0.08184801, -0.019164372, 0.06791302, 0.034257166, - -0.10307039, 0.021943003, 0.046745934, 0.0790918, - -0.0265588, -0.007824208, 0.042546265, -0.00977924, - -0.0002440307, -0.017384544, -0.017990116, 0.12252321, - -0.014512694, -0.08251313, 0.08861942, 0.13589665, - 0.026351685, 0.012641483, 0.07466548, 0.044301085, - -0.045414884, -0.051112458, 0.03444247, -0.08502782, - -0.04106223, -0.028126027, 0.028473156, 0.10467447}); - - lstm.SetRecurrentToForgetWeights( - {-0.057784554, -0.026057621, -0.068447545, -0.022581743, - 0.14811787, 0.10826372, 0.09471067, 0.03987225, - -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, - 0.08414449, -0.022036452, -0.00066928595, -0.09203576, - 0.032950465, -0.10985798, -0.023809856, 0.0021431844, - -0.02196096, -0.00326074, 0.00058621005, -0.074678116, - -0.06193199, 0.055729095, 0.03736828, 0.020123724, - 0.061878487, -0.04729229, 0.034919553, -0.07585433, - -0.04421272, -0.044019096, 0.085488975, 0.04058006, - -0.06890133, -0.030951202, -0.024628663, -0.07672815, - 0.034293607, 0.08556707, -0.05293577, -0.033561368, - -0.04899627, 0.0241671, 0.015736353, -0.095442444, - -0.029564252, 0.016493602, -0.035026584, 0.022337519, - -0.026871363, 0.004780428, 0.0077918363, -0.03601621, - 0.016435321, -0.03263031, -0.09543275, -0.047392778, - 0.013454138, 0.028934088, 0.01685226, -0.086110644, - -0.046250615, -0.01847454, 0.047608484, 0.07339695, - 0.034546845, -0.04881143, 0.009128804, -0.08802852, - 0.03761666, 0.008096139, -0.014454086, 0.014361001, - -0.023502491, -0.0011840804, -0.07607001, 0.001856849, - -0.06509276, -0.006021153, -0.08570962, -0.1451793, - 0.060212336, 0.055259194, 0.06974018, 0.049454916, - -0.027794661, -0.08077226, -0.016179763, 0.1169753, - 0.17213494, -0.0056326236, -0.053934924, -0.0124349, - -0.11520337, 0.05409887, 0.088759385, 0.0019655675, - 0.0042065294, 0.03881498, 0.019844765, 0.041858196, - -0.05695512, 0.047233116, 0.038937137, -0.06542224, - 0.014429736, -0.09719407, 0.13908425, -0.05379757, - 0.012321099, 0.082840554, -0.029899208, 0.044217527, - 0.059855383, 0.07711018, -0.045319796, 0.0948846, - -0.011724666, -0.0033288454, -0.033542685, -0.04764985, - -0.13873616, 0.040668588, 0.034832682, -0.015319203, - -0.018715994, 0.046002675, 0.0599172, -0.043107376, - 0.0294216, -0.002314414, -0.022424703, 0.0030315618, - 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, - 0.12375372, -0.0006038222, 0.029104086, 0.087442465, - 0.052958444, 0.07558703, 0.04817258, 0.044462286, - -0.015213451, -0.08783778, -0.0561384, -0.003008196, - 0.047060397, -0.002058388, 0.03429439, -0.018839769, - 0.024734668, 0.024614193, -0.042046934, 0.09597743, - -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, - -0.02558259, -0.022822596, -0.023273505, -0.02464396, - -0.10991725, -0.006240552, 0.0074488563, 0.024044557, - 0.04383914, -0.046476185, 0.028658995, 0.060410924, - 0.050786525, 0.009452605, -0.0073054377, -0.024810238, - 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, - 0.015898481, 0.021362653, -0.030262267, 0.016587038, - -0.011442813, 0.041154444, -0.007631438, -0.03423484, - -0.010977775, 0.036152758, 0.0066366293, 0.11915515, - 0.02318443, -0.041350313, 0.021485701, -0.10906167, - -0.028218046, -0.00954771, 0.020531068, -0.11995105, - -0.03672871, 0.024019798, 0.014255957, -0.05221243, - -0.00661567, -0.04630967, 0.033188973, 0.10107534, - -0.014027541, 0.030796422, -0.10270911, -0.035999842, - 0.15443139, 0.07684145, 0.036571592, -0.035900835, - -0.0034699554, 0.06209149, 0.015920248, -0.031122351, - -0.03858649, 0.01849943, 0.13872518, 0.01503974, - 0.069941424, -0.06948533, -0.0088794185, 0.061282158, - -0.047401894, 0.03100163, -0.041533746, -0.10430945, - 0.044574402, -0.01425562, -0.024290353, 0.034563623, - 0.05866852, 0.023947537, -0.09445152, 0.035450947, - 0.02247216, -0.0042998926, 0.061146557, -0.10250651, - 0.020881841, -0.06747029, 0.10062043, -0.0023941975, - 0.03532124, -0.016341697, 0.09685456, -0.016764693, - 0.051808182, 0.05875331, -0.04536488, 0.001626336, - -0.028892258, -0.01048663, -0.009793449, -0.017093895, - 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, - -0.001845119, -0.03551521, 0.0018358806, 0.05763657, - -0.01769146, 0.040995963, 0.02235177, -0.060430344, - 0.11475477, -0.023854522, 0.10071741, 0.0686208, - -0.014250481, 0.034261297, 0.047418304, 0.08562733, - -0.030519066, 0.0060542435, 0.014653856, -0.038836084, - 0.04096551, 0.032249358, -0.08355519, -0.026823482, - 0.056386515, -0.010401743, -0.028396193, 0.08507674, - 0.014410365, 0.020995233, 0.17040324, 0.11511526, - 0.02459721, 0.0066619175, 0.025853224, -0.023133837, - -0.081302024, 0.017264642, -0.009585969, 0.09491168, - -0.051313367, 0.054532815, -0.014298593, 0.10657464, - 0.007076659, 0.10964551, 0.0409152, 0.008275321, - -0.07283536, 0.07937492, 0.04192024, -0.1075027}); - - lstm.SetRecurrentToCellWeights( - {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, - 0.055647098, -0.05713207, -0.05626563, 0.005559383, - 0.03375411, -0.025757805, -0.088049285, 0.06017052, - -0.06570978, 0.007384076, 0.035123326, -0.07920549, - 0.053676967, 0.044480428, -0.07663568, 0.0071805613, - 0.08089997, 0.05143358, 0.038261272, 0.03339287, - -0.027673481, 0.044746667, 0.028349208, 0.020090483, - -0.019443132, -0.030755889, -0.0040000007, 0.04465846, - -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, - -0.10893326, 0.076739706, -0.08509834, -0.027997585, - 0.037871376, 0.01449768, -0.09002357, -0.06111149, - -0.046195522, 0.0422062, -0.005683705, -0.1253618, - -0.012925729, -0.04890792, 0.06985068, 0.037654128, - 0.03398274, -0.004781977, 0.007032333, -0.031787455, - 0.010868644, -0.031489216, 0.09525667, 0.013939797, - 0.0058680447, 0.0167067, 0.02668468, -0.04797466, - -0.048885044, -0.12722108, 0.035304096, 0.06554885, - 0.00972396, -0.039238118, -0.05159735, -0.11329045, - 0.1613692, -0.03750952, 0.06529313, -0.071974665, - -0.11769596, 0.015524369, -0.0013754242, -0.12446318, - 0.02786344, -0.014179351, 0.005264273, 0.14376344, - 0.015983658, 0.03406988, -0.06939408, 0.040699873, - 0.02111075, 0.09669095, 0.041345075, -0.08316494, - -0.07684199, -0.045768797, 0.032298047, -0.041805092, - 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, - -0.024950314, 0.11574242, 0.04508852, -0.04335324, - 0.06760663, -0.027437469, 0.07216407, 0.06977076, - -0.05438599, 0.034033038, -0.028602652, 0.05346137, - 0.043184172, -0.037189785, 0.10420091, 0.00882477, - -0.054019816, -0.074273005, -0.030617684, -0.0028467078, - 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, - 0.04361412, -0.007001822, 0.09631092, -0.06702025, - -0.042049985, -0.035070654, -0.04103342, -0.10273396, - 0.0544271, 0.037184782, -0.13150354, -0.0058036847, - -0.008264958, 0.042035464, 0.05891794, 0.029673764, - 0.0063542654, 0.044788733, 0.054816857, 0.062257513, - -0.00093483756, 0.048938446, -0.004952862, -0.007730018, - -0.04043371, -0.017094059, 0.07229206, -0.023670016, - -0.052195564, -0.025616996, -0.01520939, 0.045104615, - -0.007376126, 0.003533447, 0.006570588, 0.056037236, - 0.12436656, 0.051817212, 0.028532185, -0.08686856, - 0.11868599, 0.07663395, -0.07323171, 0.03463402, - -0.050708205, -0.04458982, -0.11590894, 0.021273347, - 0.1251325, -0.15313013, -0.12224372, 0.17228661, - 0.023029093, 0.086124025, 0.006445803, -0.03496501, - 0.028332196, 0.04449512, -0.042436164, -0.026587414, - -0.006041347, -0.09292539, -0.05678812, 0.03897832, - 0.09465633, 0.008115513, -0.02171956, 0.08304309, - 0.071401566, 0.019622514, 0.032163795, -0.004167056, - 0.02295182, 0.030739572, 0.056506045, 0.004612461, - 0.06524936, 0.059999723, 0.046395954, -0.0045512207, - -0.1335546, -0.030136576, 0.11584653, -0.014678886, - 0.0020118146, -0.09688814, -0.0790206, 0.039770417, - -0.0329582, 0.07922767, 0.029322514, 0.026405897, - 0.04207835, -0.07073373, 0.063781224, 0.0859677, - -0.10925287, -0.07011058, 0.048005477, 0.03438226, - -0.09606514, -0.006669445, -0.043381985, 0.04240257, - -0.06955775, -0.06769346, 0.043903265, -0.026784198, - -0.017840602, 0.024307009, -0.040079936, -0.019946516, - 0.045318738, -0.12233574, 0.026170589, 0.0074471775, - 0.15978073, 0.10185836, 0.10298046, -0.015476589, - -0.039390966, -0.072174534, 0.0739445, -0.1211869, - -0.0347889, -0.07943156, 0.014809798, -0.12412325, - -0.0030663363, 0.039695457, 0.0647603, -0.08291318, - -0.018529687, -0.004423833, 0.0037507233, 0.084633216, - -0.01514876, -0.056505352, -0.012800942, -0.06994386, - 0.012962922, -0.031234352, 0.07029052, 0.016418684, - 0.03618972, 0.055686004, -0.08663945, -0.017404709, - -0.054761406, 0.029065743, 0.052404847, 0.020238016, - 0.0048197987, -0.0214882, 0.07078733, 0.013016777, - 0.06262858, 0.009184685, 0.020785125, -0.043904778, - -0.0270329, -0.03299152, -0.060088247, -0.015162964, - -0.001828936, 0.12642565, -0.056757294, 0.013586685, - 0.09232601, -0.035886683, 0.06000002, 0.05229691, - -0.052580316, -0.082029596, -0.010794592, 0.012947712, - -0.036429964, -0.085508935, -0.13127148, -0.017744139, - 0.031502828, 0.036232427, -0.031581745, 0.023051167, - -0.05325106, -0.03421577, 0.028793324, -0.034633752, - -0.009881397, -0.043551125, -0.018609839, 0.0019097115, - -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); - - lstm.SetRecurrentToOutputWeights({ - 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, - -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, - -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, - -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, - -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, - -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, - -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, - 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, - -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, - 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, - -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, - -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, - 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, - 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, - -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, - 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, - 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, - 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, - 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, - 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, - -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, - 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, - -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, - 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, - 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, - 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, - -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, - -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, - -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, - -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, - -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, - -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, - 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, - 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, - -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, - 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, - -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, - -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, - -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, - 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, - 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, - 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, - -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, - 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, - -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, - -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, - -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, - -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, - 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, - -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, - 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, - -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, - -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, - -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, - -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, - 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, - 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, - -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, - 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, - 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, - -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, - 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, - 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, - 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, - }); - - lstm.SetCellToInputWeights( - {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, - -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, - -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, - 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); - - lstm.SetCellToForgetWeights( - {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, - -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, - -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, - 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); - - lstm.SetCellToOutputWeights( - {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, - -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, - -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, - 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); - - lstm.SetProjectionWeights( - {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, - 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, - -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, - -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, - 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, - 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, - 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, - 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, - -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, - -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, - -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, - 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, - 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, - 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, - 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, - 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, - -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, - 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, - -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, - 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, - -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, - -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, - 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, - -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, - 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, - -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, - -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, - 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, - -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, - -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, - -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, - 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, - 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, - -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, - 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, - 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, - 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, - 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, - 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, - -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, - -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, - 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, - -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, - -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, - 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, - 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, - 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, - -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, - -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, - -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, - 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, - -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, - 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, - 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, - -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, - -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, - -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, - 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, - -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, - -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, - -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, - 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, - 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, - 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); - - static float lstm_input[][20] = { - {// Batch0: 4 (input_sequence_size) * 5 (n_input) - 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, - 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, - 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, - - {// Batch1: 4 (input_sequence_size) * 5 (n_input) - 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, - 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, - 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; - - static float lstm_golden_output[][64] = { - {// Batch0: 4 (input_sequence_size) * 16 (n_output) - -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, - -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, - -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, - 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, - -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, - -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, - 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, - 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, - 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, - 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, - -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, - -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, - 0.0286833, 0.00824207, 0.0264887, 0.0305169}, - {// Batch1: 4 (input_sequence_size) * 16 (n_output) - -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, - -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, - 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, - 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, - -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, - -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, - 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, - 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, - 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, - 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, - -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, - -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, - 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInput(0, batch0_start, batch0_end); +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; - float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); - float* batch1_end = batch1_start + lstm.num_inputs(); - lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); - lstm.Invoke(); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); - float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); - float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); - float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); - expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 62f4e94a386fbbc6987e8a6dc1a9a47ce3349cbb..9e01b73c4933c34ce3fd549730080946674daaac 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -39,6 +39,14 @@ constexpr int kOutputTensor = 0; struct OpData { bool requires_broadcast; + + // Parameters used in the quantized paths where the output is 8bit + int32 output_activation_min; + int32 output_activation_max; + + // Parameters used in all quantized paths + int32_t output_multiplier; + int output_shift; }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -52,6 +60,7 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); @@ -62,7 +71,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TF_LITE_ENSURE_EQ(context, input1->type, input2->type); - output->type = input2->type; data->requires_broadcast = !HaveSameShapes(input1, input2); @@ -74,6 +82,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { output_size = TfLiteIntArrayCopy(input1->dims); } + if (output->type == kTfLiteUInt8) { + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + double real_multiplier = + input1->params.scale * input2->params.scale / output->params.scale; + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; + } + return context->ResizeTensor(context, output, output_size); } @@ -107,41 +129,60 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, } template -void EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - auto input1_offset = -input1->params.zero_point; - auto input2_offset = -input2->params.zero_point; - auto output_offset = output->params.zero_point; - - int32_t output_multiplier; - int output_shift; - - double real_multiplier = - input1->params.scale * input2->params.scale / output->params.scale; - QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, - &output_shift); - - int32 output_activation_min, output_activation_max; - CalculateActivationRangeUint8(params->activation, output, - &output_activation_min, &output_activation_max); - -#define TF_LITE_MUL(type, opname) \ - type::opname(GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, output_offset, \ - output_multiplier, output_shift, output_activation_min, \ - output_activation_max, GetTensorData(output), \ +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteMulParams* params, const OpData* data, + const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { + if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 && + output->type == kTfLiteUInt8) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + -input1->params.zero_point, GetTensorData(input2), \ + GetTensorDims(input2), -input2->params.zero_point, \ + output->params.zero_point, data->output_multiplier, \ + data->output_shift, data->output_activation_min, \ + data->output_activation_max, GetTensorData(output), \ GetTensorDims(output)); - // The quantized version of Mul doesn't support activations, so we - // always use BroadcastMul. - if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops, BroadcastMul); + // The quantized version of Mul doesn't support activations, so we + // always use BroadcastMul. + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, BroadcastMul); + } else { + TF_LITE_MUL(optimized_ops, BroadcastMul); + } +#undef TF_LITE_MUL + } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && + output->type == kTfLiteInt16) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + GetTensorData(output), GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, Mul); + } else { + TF_LITE_MUL(optimized_ops, Mul); + } +#undef TF_LITE_MUL + } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && + output->type == kTfLiteUInt8) { +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output->params.zero_point, data->output_activation_min, \ + data->output_activation_max, GetTensorData(output), \ + GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops, Mul); + } else { + TF_LITE_MUL(optimized_ops, Mul); + } +#undef TF_LITE_MUL } else { - TF_LITE_MUL(optimized_ops, BroadcastMul); + context->ReportError( + context, "Unsupported combination of input and output types in Mul."); + return kTfLiteError; } -#undef TF_LITE_MUL + return kTfLiteOk; } template @@ -155,12 +196,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { EvalFloat(context, node, params, data, input1, input2, output); - } else if (output->type == kTfLiteUInt8) { - EvalQuantized(context, node, params, data, input1, input2, - output); + } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { + TF_LITE_ENSURE_OK( + context, EvalQuantized(context, node, params, data, input1, + input2, output)); } else { context->ReportError( - context, "Mul only supports FLOAT32 and quantized UINT8 now, got %d.", + context, + "Mul only supports FLOAT32 and quantized UINT8 and INT16 now, got %d.", output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc index f1a30f82634631ba8320421d5b36ffe446f443fa..43d56e50d2686ff2624f36a0c5d8e43279a572cc 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -58,6 +58,9 @@ class FloatMulOpModel : public BaseMulOpModel { const float kQuantizedStep = 2.0 / 255.0; const float kQuantizedTolerance = 2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep; +const float kQuantizedStepInt16 = 2.0 / 32767.0; +const float kQuantizedToleranceInt16 = + 2.0 * kQuantizedStepInt16 + kQuantizedStepInt16 * kQuantizedStepInt16; class QuantizedMulOpModel : public BaseMulOpModel { public: @@ -67,6 +70,11 @@ class QuantizedMulOpModel : public BaseMulOpModel { return Dequantize(ExtractVector(output_), GetScale(output_), GetZeroPoint(output_)); } + + std::vector GetDequantizedOutputInt16() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } }; TEST(FloatMulOpTest, NoActivation) { @@ -138,6 +146,38 @@ TEST(QuantizedMulOpTest, NoActivation) { kQuantizedTolerance))); } +TEST(QuantizedMulOpTest, NoActivationInt16) { + const float kMin = -1.f; + const float kMax = 32767.f / 32768.f; + QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax}, + {TensorType_INT16, {}, kMin, kMax}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutputInt16(), + ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, + kQuantizedToleranceInt16))); +} + +TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) { + const float kMinInt16 = -1.f; + const float kMaxInt16 = 32767.f / 32768.f; + const float kMinUint8 = -1.f; + const float kMaxUint8 = 127.f / 128.f; + QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16}, + {TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16}, + {TensorType_UINT8, {}, kMinUint8, kMaxUint8}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, + kQuantizedTolerance))); +} + // for quantized Mul, the error shouldn't exceed 2*step float GetTolerance(int min, int max) { float kQuantizedStep = (max - min) / 255.0; diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index bcad58406af1cdd466e410a06011641692194be4..1c728a473326564a85a5e7d3d72718265979e29a 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -95,6 +95,12 @@ class LSTMOpModel : public SingleOpModel { projection_bias_ = AddNullInput(); } + // Adding the 2 input state tensors. + input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + output_state_ = AddOutput(TensorType_FLOAT32); cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -228,6 +234,8 @@ class LSTMOpModel : public SingleOpModel { int projection_weights_; int projection_bias_; + int input_activation_state_; + int input_cell_state_; int output_; int output_state_; diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc index 311e9b8399726d758182e1f084a890d6f10e57ce..41771e60bc6273c1b22fa7f3e996903202334306 100644 --- a/tensorflow/contrib/lite/kernels/pooling.cc +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -126,12 +126,13 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node, float activation_min, activation_max; CalculateActivationRangeFloat(params->activation, &activation_min, &activation_max); -#define TF_LITE_AVERAGE_POOL(type) \ - type::AveragePool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, \ + activation_min, activation_max, \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_AVERAGE_POOL(reference_ops); } else { @@ -148,13 +149,13 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t activation_max; CalculateActivationRangeUint8(params->activation, output, &activation_min, &activation_max); -#define TF_LITE_AVERAGE_POOL(type) \ - type::AveragePool(GetTensorData(input), GetTensorDims(input), \ - params->stride_width, params->stride_height, \ - data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, \ - activation_min, activation_max, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, \ + activation_min, activation_max, \ + GetTensorData(output), GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_AVERAGE_POOL(reference_ops); } else { @@ -170,12 +171,13 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, float activation_min, activation_max; CalculateActivationRangeFloat(params->activation, &activation_min, &activation_max); -#define TF_LITE_MAX_POOL(type) \ - type::MaxPool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_MAX_POOL(type) \ + type::MaxPool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), \ + GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_MAX_POOL(reference_ops); } else { @@ -193,12 +195,12 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeUint8(params->activation, output, &activation_min, &activation_max); #define TF_LITE_MAX_POOL(type) \ - type::MaxPool(GetTensorData(input), GetTensorDims(input), \ + type::MaxPool(GetTensorData(input), GetTensorShape(input), \ params->stride_width, params->stride_height, \ data->padding.width, data->padding.height, \ params->filter_width, params->filter_height, activation_min, \ activation_max, GetTensorData(output), \ - GetTensorDims(output)) + GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_MAX_POOL(reference_ops); } else { @@ -214,12 +216,13 @@ void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, float activation_min, activation_max; CalculateActivationRangeFloat(params->activation, &activation_min, &activation_max); -#define TF_LITE_L2_POOL(type) \ - type::L2Pool( \ - GetTensorData(input), GetTensorDims(input), params->stride_width, \ - params->stride_height, data->padding.width, data->padding.height, \ - params->filter_width, params->filter_height, activation_min, \ - activation_max, GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_L2_POOL(type) \ + type::L2Pool(GetTensorData(input), GetTensorShape(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), \ + GetTensorShape(output)) if (kernel_type == kReference) { TF_LITE_L2_POOL(reference_ops); } else { diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/reduce.cc similarity index 72% rename from tensorflow/contrib/lite/kernels/mean.cc rename to tensorflow/contrib/lite/kernels/reduce.cc index 03e5db24de3f3c2d4e17df21bc0b592a02078d6b..31c331a8c61ded203af9ff2ae127cb6f985e2932 100644 --- a/tensorflow/contrib/lite/kernels/mean.cc +++ b/tensorflow/contrib/lite/kernels/reduce.cc @@ -25,21 +25,21 @@ limitations under the License. namespace tflite { namespace ops { namespace builtin { -namespace mean { +namespace reduce { -// This file has reference implementation of Mean. +// This file has reference implementation of reduce_* operators. enum KernelType { kReference, }; -struct MeanContext { - MeanContext(TfLiteContext* context, TfLiteNode* node) { - params = reinterpret_cast(node->builtin_data); +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); input = GetInput(context, node, 0); axis = GetInput(context, node, 1); output = GetOutput(context, node, 0); } - TfLiteMeanParams* params; + TfLiteReducerParams* params; const TfLiteTensor* input; const TfLiteTensor* axis; TfLiteTensor* output; @@ -58,7 +58,7 @@ void Free(TfLiteContext* context, void* buffer) { } // Resizes the temp tensor that stores resolved axis. -TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, +TfLiteStatus ResizeTempAxis(TfLiteContext* context, OpContext* op_context, TfLiteTensor* resolved_axis) { TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1); axis_size->data[0] = static_cast(NumElements(op_context->axis)); @@ -66,7 +66,7 @@ TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, } // Resizes the temp tensor that stores temp sum of reduced elements. -TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context, +TfLiteStatus ResizeTempSum(TfLiteContext* context, OpContext* op_context, TfLiteTensor* temp_sum) { TfLiteIntArray* size = TfLiteIntArrayCreate(1); size->data[0] = static_cast(NumElements(op_context->output)); @@ -74,8 +74,7 @@ TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context, } // Resizes output array based on the input size and resolved axis. -TfLiteStatus ResizeOutputTensor(TfLiteContext* context, - MeanContext* op_context) { +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) { size_t num_axis = NumElements(op_context->axis); const TfLiteIntArray* input_dims = op_context->input->dims; int input_num_dims = NumDimensions(op_context->input); @@ -140,7 +139,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, // Initializes temp tensors to store index and resolved axis. TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, - MeanContext* op_context) { + OpContext* op_context) { // Creates a temp index to iterate through input data. int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); @@ -180,33 +179,44 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - MeanContext op_context(context, node); + OpContext op_context(context, node); TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context)); TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); - TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); // Leaves work to Eval if axis is not constant; else resizes output. if (!IsConstantTensor(op_context.axis)) { SetTensorToDynamic(op_context.output); SetTensorToDynamic(resolved_axis); - SetTensorToDynamic(temp_sum); return kTfLiteOk; } resolved_axis->allocation_type = kTfLiteArenaRw; TF_LITE_ENSURE_OK(context, ResizeTempAxis(context, &op_context, resolved_axis)); TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + return kTfLiteOk; +} + +TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_OK(context, PrepareSimple(context, node)); + + // reduce_mean requires a buffer to store intermediate sum result. + OpContext op_context(context, node); + TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); + if (!IsConstantTensor(op_context.axis)) { + SetTensorToDynamic(temp_sum); + return kTfLiteOk; + } temp_sum->allocation_type = kTfLiteArenaRw; return ResizeTempSum(context, &op_context, temp_sum); } template -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - MeanContext op_context(context, node); +TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); int num_axis = static_cast(NumElements(op_context.axis)); TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); @@ -255,16 +265,75 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { #undef TF_LITE_MEAN return kTfLiteOk; } -} // namespace mean + +template +TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + int num_axis = static_cast(NumElements(op_context.axis)); + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_SUM(kernel_type, data_type) \ + kernel_type::Sum<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ + GetTensorData(resolved_axis)) + + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, float)); + break; + case kTfLiteInt32: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int)); + break; + case kTfLiteInt64: + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int64_t)); + break; + case kTfLiteUInt8: + TF_LITE_ENSURE_EQ(context, op_context.input->params.scale, + op_context.output->params.scale); + TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, + op_context.output->params.zero_point); + TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, uint8_t)); + break; + default: + return kTfLiteError; + } + } +#undef TF_LITE_SUM + return kTfLiteOk; +} + +} // namespace reduce TfLiteRegistration* Register_MEAN_REF() { - static TfLiteRegistration r = {mean::Init, mean::Free, mean::Prepare, - mean::Eval}; + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareMean, + reduce::EvalMean}; + return &r; +} + +TfLiteRegistration* Register_SUM_REF() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, + reduce::PrepareSimple, + reduce::EvalSum}; return &r; } // TODO(kanlig): add optimized implementation of Mean. TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); } +TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); } } // namespace builtin } // namespace ops diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc similarity index 53% rename from tensorflow/contrib/lite/kernels/mean_test.cc rename to tensorflow/contrib/lite/kernels/reduce_test.cc index 79c9957f76fdb994be0a71f2e90b883435de4815..9e946822c686f6f20505d60b6161239624c94696 100644 --- a/tensorflow/contrib/lite/kernels/mean_test.cc +++ b/tensorflow/contrib/lite/kernels/reduce_test.cc @@ -23,7 +23,7 @@ namespace { using ::testing::ElementsAreArray; -class BaseMeanOpModel : public SingleOpModel { +class BaseOpModel : public SingleOpModel { public: void SetAxis(std::initializer_list data) { PopulateTensor(axis_, data); } @@ -53,7 +53,7 @@ class BaseMeanOpModel : public SingleOpModel { }; // Model for the tests case where axis is a const tensor. -class MeanOpConstModel : public BaseMeanOpModel { +class MeanOpConstModel : public BaseOpModel { public: MeanOpConstModel(const TensorData& input, const TensorData& output, std::initializer_list axis_shape, @@ -61,26 +61,59 @@ class MeanOpConstModel : public BaseMeanOpModel { input_ = AddInput(input); axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, - CreateMeanOptions(builder_, keep_dims).Union()); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); BuildInterpreter({GetShape(input_)}); } }; // Model for the tests case where axis is a dynamic tensor. -class MeanOpDynamicModel : public BaseMeanOpModel { +class MeanOpDynamicModel : public BaseOpModel { public: MeanOpDynamicModel(const TensorData& input, const TensorData& output, const TensorData& axis, bool keep_dims) { input_ = AddInput(input); axis_ = AddInput(axis); output_ = AddOutput(output); - SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, - CreateMeanOptions(builder_, keep_dims).Union()); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); BuildInterpreter({GetShape(input_)}); } }; +// Model for the tests case where axis is a const tensor. +class SumOpConstModel : public BaseOpModel { + public: + SumOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. +class SumOpDynamicModel : public BaseOpModel { + public: + SumOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// for quantized Add, the error shouldn't exceed step +float GetTolerance(int min, int max) { return (max - min) / 255.0; } + +// Tests for reduce_mean TEST(ConstFloatMeanOpTest, NotKeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, @@ -149,8 +182,6 @@ TEST(DynamicFloatMeanOpTest, Scale) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); } -// for quantized Add, the error shouldn't exceed step -float GetTolerance(int min, int max) { return (max - min) / 255.0; } TEST(ConstUint8MeanOpTest, NotKeepDims) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); @@ -209,6 +240,135 @@ TEST(DynamicUint8MeanOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance))); } +// Tests for reduce_sum + +TEST(ConstFloatSumOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({144, 156}))); +} + +TEST(ConstFloatSumOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({84, 100, 116}))); +} + +TEST(DynamicFloatSumOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::initializer_list axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({144, 156}))); +} + +TEST(DynamicFloatSumOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true); + std::initializer_list axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({84, 100, 116}))); +} + +TEST(DynamicFloatSumOpTest, Scale) { + std::initializer_list data = {9.527}; + SumOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::initializer_list axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + +TEST(ConstUint8SumOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance))); +} + +TEST(ConstUint8SumOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::initializer_list data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + SumOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0}, + {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-0.407843, -0.313726, 0.0941177}, + kQuantizedTolerance))); +} + +TEST(DynamicUint8SumOpTest, NotKeepDims) { + float kQuantizedTolerance = GetTolerance(-5.0, 2.0); + std::initializer_list data = {1.3, -4.8, -3.6, 0.24}; + SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0}, + {TensorType_UINT8, {2}, -5.0, 2.0}, + {TensorType_INT32, {1}}, false); + std::initializer_list axis = {1}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray( + ArrayFloatNear({1.48235, 1.64706}, kQuantizedTolerance))); +} + +TEST(DynamicUint8SumOpTest, KeepDims) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::initializer_list data = {11.14, -0.14, 7.423, 0.879}; + SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0}, + {TensorType_UINT8, {2}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::initializer_list axis = {0}; + m.SetAxis(axis); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 184b02dcecc580db51a92a1987525ad7e54ee010..67f6caea678f840076f839a2203d047d1e63329d 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -22,6 +22,7 @@ namespace custom { TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); TfLiteRegistration* Register_MFCC(); +TfLiteRegistration* Register_DETECTION_POSTPROCESS(); } // namespace custom @@ -73,6 +74,7 @@ TfLiteRegistration* Register_SQUEEZE(); TfLiteRegistration* Register_STRIDED_SLICE(); TfLiteRegistration* Register_EXP(); TfLiteRegistration* Register_TOPK_V2(); +TfLiteRegistration* Register_LOG(); TfLiteRegistration* Register_LOG_SOFTMAX(); TfLiteRegistration* Register_CAST(); TfLiteRegistration* Register_DEQUANTIZE(); @@ -87,12 +89,18 @@ TfLiteRegistration* Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); TfLiteRegistration* Register_TILE(); TfLiteRegistration* Register_NEG(); +TfLiteRegistration* Register_SUM(); TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); TfLiteRegistration* Register_TRANSPOSE_CONV(); TfLiteRegistration* Register_EXPAND_DIMS(); TfLiteRegistration* Register_SPARSE_TO_DENSE(); +TfLiteRegistration* Register_EQUAL(); +TfLiteRegistration* Register_NOT_EQUAL(); +TfLiteRegistration* Register_SQRT(); +TfLiteRegistration* Register_RSQRT(); +TfLiteRegistration* Register_SHAPE(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -148,6 +156,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); AddBuiltin(BuiltinOperator_EXP, Register_EXP()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); + AddBuiltin(BuiltinOperator_LOG, Register_LOG()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); AddBuiltin(BuiltinOperator_CAST, Register_CAST()); AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE()); @@ -166,14 +175,22 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); AddBuiltin(BuiltinOperator_TILE, Register_TILE()); + AddBuiltin(BuiltinOperator_SUM, Register_SUM()); AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); + AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); + AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL()); + AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); + AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); + AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); AddCustom("AudioSpectrogram", tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); + AddCustom("TFLite_Detection_PostProcess", + tflite::ops::custom::Register_DETECTION_POSTPROCESS()); } } // namespace builtin diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index b928f1b302580d52f708bbf85dfcfc0f79ff1e69..940718d67e70b7206227b891ea529cb9e9619161 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -32,4 +32,4 @@ class BuiltinOpResolver : public MutableOpResolver { } // namespace ops } // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index f2092eaa36db32ebbc959ac23365bb13dd034e68..86c4cd3ee88013ca4174f444d0388bc036d9cde6 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -61,12 +61,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1); - // TODO(ahentz): Our current implementations only support float32. - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32); // ResizeBilinear creates a float tensor even when the input is made of // integers. - output->type = kTfLiteFloat32; + output->type = input->type; if (!IsConstantTensor(size)) { SetTensorToDynamic(output); @@ -90,17 +88,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } if (output->type == kTfLiteFloat32) { -#define TF_LITE_RESIZE_BILINEAR(type) \ - type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ - GetTensorData(size), GetTensorDims(size), \ - GetTensorData(output), GetTensorDims(output), \ +#define TF_LITE_RESIZE_BILINEAR(type, datatype) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(size), GetTensorDims(size), \ + GetTensorData(output), GetTensorDims(output), \ params->align_corners) if (kernel_type == kReference) { - TF_LITE_RESIZE_BILINEAR(reference_ops); + TF_LITE_RESIZE_BILINEAR(reference_ops, float); } if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { - TF_LITE_RESIZE_BILINEAR(optimized_ops); + TF_LITE_RESIZE_BILINEAR(optimized_ops, float); + } + } else if (output->type == kTfLiteUInt8) { + if (kernel_type == kReference) { + TF_LITE_RESIZE_BILINEAR(reference_ops, uint8_t); + } + if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { + TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t); } #undef TF_LITE_RESIZE_BILINEAR } else { diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 4e03f3820a5c14ee1692c553db61e385716b1723..10caffea03ebcec7862df1627541ac3d076b04e4 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -22,6 +22,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using uint8 = std::uint8_t; class ResizeBilinearOpModel : public SingleOpModel { public: @@ -34,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel { } else { size_ = AddInput({TensorType_INT32, {2}}); } - output_ = AddOutput(TensorType_FLOAT32); // Always float. + output_ = AddOutput(input.type); SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, CreateResizeBilinearOptions(builder_).Union()); @@ -45,12 +46,16 @@ class ResizeBilinearOpModel : public SingleOpModel { } } - void SetInput(std::initializer_list data) { + template + void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } void SetSize(std::initializer_list data) { PopulateTensor(size_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + std::vector GetOutput() { + return ExtractVector(output_); + } private: int input_; @@ -60,60 +65,121 @@ class ResizeBilinearOpModel : public SingleOpModel { TEST(ResizeBilinearOpTest, HorizontalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); - m.SetInput({3, 6}); + m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); - const_m.SetInput({3, 6}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(ResizeBilinearOpTest, HorizontalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}); + m.SetInput({3, 6}); + m.SetSize({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } TEST(ResizeBilinearOpTest, VerticalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); - m.SetInput({3, 9}); + m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); - const_m.SetInput({3, 9}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(ResizeBilinearOpTest, VerticalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}); + m.SetInput({3, 9}); + m.SetSize({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } TEST(ResizeBilinearOpTest, TwoDimensionalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); - m.SetInput({ + m.SetInput({ 3, 6, // 9, 12 // }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 6, // 9, 12 // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); } TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); - m.SetInput({ + m.SetInput({ 3, 6, // 9, 12, // 4, 10, // @@ -121,60 +187,123 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - 4, 8, 10, // - 8, 12, 14, // - 10, 14, 16, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 6, // 9, 12, // 4, 10, // 10, 16 // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - 4, 8, 10, // - 8, 12, 14, // - 10, 14, 16, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); - m.SetInput({ + m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 14, 12, 16, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 14, 12, 16, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 13, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 13, 16, // + }))); } +TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}); + m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 13, 12, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 13, 12, 16, // + }))); +} } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc new file mode 100644 index 0000000000000000000000000000000000000000..dbcd2ef004f490f00193153be7a2cfda83e73c24 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/shape.cc @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace shape { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +template +void ExtractShape(const TfLiteTensor* input, OutType* output_data) { + for (int i = 0; i < NumDimensions(input); ++i) { + output_data[i] = SizeOfDimension(input, i); + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + auto* params = reinterpret_cast(node->builtin_data); + switch (params->out_type) { + case kTfLiteInt32: + output->type = kTfLiteInt32; + break; + case kTfLiteInt64: + output->type = kTfLiteInt64; + break; + default: + context->ReportError(context, "Unknown shape output data type: %d", + params->out_type); + return kTfLiteError; + } + + // Shape always produces a 1-dimensional output tensor, where each output + // element is the length of the corresponding input tensor's dimension. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(1); + output_size->data[0] = NumDimensions(input); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TFLITE_DCHECK_EQ(NumDimensions(output), 1); + TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input)); + + switch (output->type) { + case kTfLiteInt32: + ExtractShape(input, GetTensorData(output)); + break; + case kTfLiteInt64: + ExtractShape(input, GetTensorData(output)); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace shape + +TfLiteRegistration* Register_SHAPE() { + static TfLiteRegistration r = {nullptr, nullptr, shape::Prepare, shape::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/shape_test.cc b/tensorflow/contrib/lite/kernels/shape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..27b48f4e992a8f02d56815bd1bd9074f5b41f400 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/shape_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class ShapeOpModel : public SingleOpModel { + public: + ShapeOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType output_type) { + input_ = AddInput(input_type); + output_ = AddOutput(output_type); + SetBuiltinOp(BuiltinOperator_SHAPE, BuiltinOptions_ShapeOptions, + CreateShapeOptions(builder_, output_type).Union()); + BuildInterpreter({input_shape}); + } + + TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); } + + int input() { return input_; } + + int32_t GetOutputSize() { return GetTensorSize(output_); } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(ShapeOpTest, OutTypeInt) { + ShapeOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32, + TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(ShapeOpTest, OutTypeInt64) { + ShapeOpModel model({1, 3, 1, 3, 5}, TensorType_FLOAT32, + TensorType_INT64); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(ShapeOpTest, ScalarTensor) { + ShapeOpModel model({}, TensorType_FLOAT32, TensorType_INT32); + model.Invoke(); + + EXPECT_EQ(model.GetOutputSize(), 0); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0})); +} + +TEST(ShapeOpTest, EmptyTensor) { + ShapeOpModel model({1, 0}, TensorType_FLOAT32, TensorType_INT32); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0})); + EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc index 6c5338ff0fd26337c9adc8e0b94a0a88edfde37f..727822f6beaa8a63ca8f1b57ba4993d2e59f7e0b 100644 --- a/tensorflow/contrib/lite/kernels/softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -92,10 +92,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::Softmax(input_buffer, input_dims, beta, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::Softmax(input_buffer, input_shape, beta, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), @@ -120,10 +119,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { m.Invoke(); std::unique_ptr output_buffer(new float[input_size * batch_size]); - static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, - {1, 0, 0, input_size}}; - tflite::reference_ops::Softmax(input_buffer, input_dims, beta, - output_buffer.get(), input_dims); + auto input_shape = RuntimeShape({batch_size, 1, 1, input_size}); + tflite::reference_ops::Softmax(input_buffer, input_shape, beta, + output_buffer.get(), input_shape); std::vector expected; expected.insert(expected.end(), output_buffer.get(), diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc index 43387df9ceb4d54a2784c3fa4718a95262948729..b14448604123253bac9c50c21f047891721ab122 100644 --- a/tensorflow/contrib/lite/kernels/split.cc +++ b/tensorflow/contrib/lite/kernels/split.cc @@ -76,8 +76,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits); auto input_type = op_context.input->type; - TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || + input_type == kTfLiteUInt8 || + input_type == kTfLiteInt16); for (int i = 0; i < NumOutputs(node); ++i) { GetOutput(context, node, i)->type = input_type; } @@ -137,9 +138,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SPLIT(uint8_t); break; } + case kTfLiteInt16: { + TF_LITE_SPLIT(int16_t); + break; + } default: context->ReportError( - context, "Only float32 and uint8 are currently supported, got %d.", + context, + "Only float32, uint8 and int16 are currently supported, got %d.", op_context.input->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index d788159a8d80e6479024b7b75624839387a461c7..a8b803589962032db3ed579d31e8b736c3afada0 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -126,16 +126,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, int32 input1_multiplier; int input1_shift; - QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, - &input1_shift); + QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, + &input1_multiplier, &input1_shift); + input1_shift *= -1; int32 input2_multiplier; int input2_shift; - QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, - &input2_shift); + QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, + &input2_multiplier, &input2_shift); + input2_shift *= -1; int32 output_multiplier; int output_shift; - QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, - &output_shift); + QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, + &output_multiplier, &output_shift); + output_shift *= -1; int32 output_activation_min, output_activation_max; CalculateActivationRangeUint8(params->activation, output, @@ -175,7 +178,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output); } else { context->ReportError( - context, "output type %d is not support, requires float|uint8 types.", + context, "output type %d is not supported, requires float|uint8 types.", output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index d23ec201b41887b0682242687fc938d76d058c44..9156917140b5af6c0f38c878ab77fef7f93b049a 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -32,8 +32,8 @@ std::vector> ArrayFloatNear(const std::vector& values, return matchers; } -int SingleOpModel::AddInput(const TensorData& t) { - int id = AddTensor(t, {}); +int SingleOpModel::AddInput(const TensorData& t, bool is_variable) { + int id = AddTensor(t, {}, is_variable); inputs_.push_back(id); return id; } @@ -120,6 +120,7 @@ void SingleOpModel::BuildInterpreter( CHECK(interpreter_->AllocateTensors() == kTfLiteOk) << "Cannot allocate tensors"; + interpreter_->ResetVariableTensorsToZero(); } void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); } diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index db80c0082c394a2cb2f9388d3db5bd1a7cbe6266..5094e1343aa7b31537333ef5a770dbbfe3954f55 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -126,8 +126,10 @@ class SingleOpModel { SingleOpModel& operator=(const SingleOpModel&) = delete; // Add a TensorType input tensor and return its index. - int AddInput(TensorType type) { return AddInput(TensorData{type}); } - int AddInput(const TensorData& t); + int AddInput(TensorType type, bool is_variable = false) { + return AddInput(TensorData{type}, is_variable); + } + int AddInput(const TensorData& t, bool is_variable = false); // Templated version of AddConstInput(). template @@ -260,7 +262,8 @@ class SingleOpModel { } template - int AddTensor(TensorData t, std::initializer_list data) { + int AddTensor(TensorData t, std::initializer_list data, + bool is_variable = false) { int id = tensors_.size(); // This is slightly different depending on whether we are adding a @@ -277,6 +280,9 @@ class SingleOpModel { } else if (t.type == TensorType_INT32) { std::tie(t.scale, t.zero_point) = QuantizationParams(t.min, t.max); + } else if (t.type == TensorType_INT16) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); } else { LOG(FATAL) << "No support for the requested quantized type"; } @@ -309,7 +315,7 @@ class SingleOpModel { tensors_.push_back(CreateTensor(builder_, builder_.CreateVector(t.shape), t.type, /*buffer=*/buffer_id, - /*name=*/0, q_params)); + /*name=*/0, q_params, is_variable)); tensor_data_[id] = t; diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index 3c99661029ed1ac881536f83519dcec355c60d50..8b9deeed20d761876d526c07eb78b602ca7314dc 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -79,7 +79,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Ensure that weights and inputs have the same channel dimension. // Note: TOCO will reorder weights in the following format: OHWI. TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3), - SizeOfDimension(weights, 0)); + SizeOfDimension(weights, 3)); if (!IsConstantTensor(output_shape)) { SetTensorToDynamic(output); @@ -119,10 +119,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Currently only support float32. switch (input->type) { case kTfLiteFloat32: - optimized_ops::TransposeConv( + reference_ops::TransposeConv( GetTensorData(input), GetTensorDims(input), GetTensorData(weights), GetTensorDims(weights), stride_width, stride_height, padding_size.width, padding_size.height, + GetTensorData(output), GetTensorDims(output), + // Last two args specify im2col which reference_ops ignores. + // (Note this does not lead to a performance regression, as the + // previous optimized version was just a copy of the reference code.) + // TODO(b/110208176): Allocate im2col tensors and switch to + // optimized_ops. GetTensorData(output), GetTensorDims(output)); break; default: diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc index 52be08934997f484337e4a3592bc7af832601695..55df8971806ed0baae9f5bcaebd24fb8065ec300 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc @@ -88,10 +88,10 @@ TEST(TransposeConvOpModelTest, SimpleTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1]) TEST(TransposeConvOpModelTest, TwoFiltersTest) { - TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_SAME, 1, 1); + TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_SAME, 1, 1); m.PopulateTensor(m.output_shape(), {1, 4, 4, 1}); - m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, - 8, 10, 12, 14, 16, 18}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18}); m.PopulateTensor( m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -117,10 +117,10 @@ TEST(TransposeConvOpModelTest, TwoFiltersTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18]) TEST(TransposeConvOpModelTest, PaddingValidTest) { - TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_VALID, 1, 1); + TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_VALID, 1, 1); m.PopulateTensor(m.output_shape(), {1, 6, 6, 1}); - m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, - 8, 10, 12, 14, 16, 18}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18}); m.PopulateTensor( m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -171,10 +171,10 @@ TEST(TransposeConvOpModelTest, StrideValidTest) { // [1, 2, 2, 1 ], // "VALID") TEST(TransposeConvOpModelTest, MultiChannelTest) { - TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 2}, Padding_VALID, 2, 2); + TransposeConvOpModel m({1, 2, 2, 1}, {2, 3, 3, 1}, Padding_VALID, 2, 2); m.PopulateTensor(m.output_shape(), {1, 5, 5, 2}); - m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18}); + m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, + 8, 10, 12, 14, 16, 18}); m.PopulateTensor(m.input(), {1, 2, 3, 4}); m.Invoke(); diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 8d8d74adfb5c40b41f9cdd22d1055d8689727278..e1ec2d6d5789da1f5d31981501a503d7b610b336 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -45,6 +45,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, case TensorType_FLOAT32: *type = kTfLiteFloat32; break; + case TensorType_INT16: + *type = kTfLiteInt16; + break; case TensorType_INT32: *type = kTfLiteInt32; break; @@ -322,12 +325,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = nullptr; switch (op_type) { - case BuiltinOperator_CALL: - // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are - // ok for now, since there is no call implementation either. - break; - case BuiltinOperator_CUSTOM: - break; case BuiltinOperator_CONV_2D: { TfLiteConvParams* params = MallocPOD(); if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { @@ -343,21 +340,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_TANH: - case BuiltinOperator_LOGISTIC: - case BuiltinOperator_RELU: - case BuiltinOperator_RELU_N1_TO_1: - case BuiltinOperator_RELU6: - case BuiltinOperator_CONCAT_EMBEDDINGS: - case BuiltinOperator_EXP: - case BuiltinOperator_TOPK_V2: - case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_DEQUANTIZE: - case BuiltinOperator_PRELU: - case BuiltinOperator_FLOOR: - case BuiltinOperator_NEG: - case BuiltinOperator_SIN: - break; case BuiltinOperator_CAST: { TfLiteCastParams* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_CastOptions()) { @@ -445,9 +427,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_EMBEDDING_LOOKUP: - // no-op. - break; case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { TfLiteEmbeddingLookupSparseParams* params = MallocPOD(); @@ -579,12 +558,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_PAD: { - break; - } - case BuiltinOperator_PADV2: { - break; - } case BuiltinOperator_RESHAPE: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { @@ -624,18 +597,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_SPACE_TO_BATCH_ND: { - break; - } - case BuiltinOperator_BATCH_TO_SPACE_ND: { - break; - } - case BuiltinOperator_TRANSPOSE: { - break; - } - case BuiltinOperator_MEAN: { - auto* params = MallocPOD(); - if (auto* schema_params = op->builtin_options_as_MeanOptions()) { + case BuiltinOperator_MEAN: + case BuiltinOperator_SUM: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { params->keep_dims = schema_params->keep_dims(); } *builtin_data = reinterpret_cast(params); @@ -672,10 +637,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_MAXIMUM: - case BuiltinOperator_MINIMUM: { - break; - } case BuiltinOperator_ARG_MAX: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { @@ -685,16 +646,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_GREATER: - case BuiltinOperator_GREATER_EQUAL: - case BuiltinOperator_LESS: - case BuiltinOperator_LESS_EQUAL: - case BuiltinOperator_SELECT: { - break; - } - case BuiltinOperator_SLICE: { - break; - } case BuiltinOperator_TRANSPOSE_CONV: { TfLiteTransposeConvParams* params = MallocPOD(); @@ -717,15 +668,62 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_SHAPE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ShapeOptions()) { + ConvertTensorType(schema_params->out_type(), ¶ms->out_type, + error_reporter); + } + *builtin_data = static_cast(params); + break; + } case BuiltinOperator_DELEGATE: { // TODO(ycling): Revisit when supporting saving delegated models. error_reporter->Report("DELEGATE op shouldn't exist in model."); return kTfLiteError; } + + // Below are the ops with no builtin_data strcture. + case BuiltinOperator_BATCH_TO_SPACE_ND: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. + case BuiltinOperator_CALL: + case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_CUSTOM: + case BuiltinOperator_DEQUANTIZE: + case BuiltinOperator_EMBEDDING_LOOKUP: + case BuiltinOperator_EQUAL: + case BuiltinOperator_EXP: case BuiltinOperator_EXPAND_DIMS: - case BuiltinOperator_TILE: { + case BuiltinOperator_FLOOR: + case BuiltinOperator_GREATER: + case BuiltinOperator_GREATER_EQUAL: + case BuiltinOperator_LESS: + case BuiltinOperator_LESS_EQUAL: + case BuiltinOperator_LOG: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_MAXIMUM: + case BuiltinOperator_MINIMUM: + case BuiltinOperator_NEG: + case BuiltinOperator_NOT_EQUAL: + case BuiltinOperator_PAD: + case BuiltinOperator_PADV2: + case BuiltinOperator_PRELU: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU6: + case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_RSQRT: + case BuiltinOperator_SELECT: + case BuiltinOperator_SIN: + case BuiltinOperator_SLICE: + case BuiltinOperator_SPACE_TO_BATCH_ND: + case BuiltinOperator_SQRT: + case BuiltinOperator_TANH: + case BuiltinOperator_TILE: + case BuiltinOperator_TOPK_V2: + case BuiltinOperator_TRANSPOSE: break; - } } return kTfLiteOk; } @@ -866,7 +864,16 @@ TfLiteStatus InterpreterBuilder::ParseTensors( const char* buffer_ptr; TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); + bool is_variable = tensor->is_variable(); if (buffer_ptr) { + if (is_variable) { + error_reporter_->Report( + "Tensor %d is a variable tensor with buffer. " + "It's not supported now.\n", + i); + status = kTfLiteError; + } + if (interpreter->SetTensorParametersReadOnly( i, type, get_name(tensor), dims, quantization, buffer_ptr, buffer_size, allocation_) != kTfLiteOk) { @@ -875,8 +882,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors( status = kTfLiteError; } } else { - if (interpreter->SetTensorParametersReadWrite( - i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { + if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor), + dims, quantization, + is_variable) != kTfLiteOk) { error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", i); status = kTfLiteError; @@ -960,6 +968,15 @@ TfLiteStatus InterpreterBuilder::operator()( if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) return cleanup_and_error(); + std::vector variables; + for (int i = 0; i < (*interpreter)->tensors_size(); ++i) { + auto* tensor = (*interpreter)->tensor(i); + if (tensor->is_variable) { + variables.push_back(i); + } + } + (**interpreter).SetVariables(variables); + return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD index f8767b443a2aa64b666c3b6bfb7db30cc0be62ea..f18a2ca07a5f66b760e96a6d9a57db8d6c26b7b9 100644 --- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -1,3 +1,5 @@ +load("@build_bazel_rules_android//android:rules.bzl", "android_binary") + package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc index e6c8d966f1aff5a867f9469f8fcdec526df84763..c7e08814fdf502f1ecfea60af3385fc7aa6055fa 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc +++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc @@ -35,8 +35,8 @@ const char kModelName[] = "smartreply_ondevice_model.bin"; const char kSamples[] = "smartreply_samples.tsv"; string TestDataPath() { - return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", - "contrib/lite/models/testdata/")); + return string(absl::StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", + "contrib/lite/models/testdata/")); } MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") { @@ -55,7 +55,7 @@ class PredictorTest : public ::testing::Test { protected: PredictorTest() { model_ = tflite::FlatBufferModel::BuildFromFile( - StrCat(TestDataPath(), "/", kModelName).c_str()); + absl::StrCat(TestDataPath(), "/", kModelName).c_str()); CHECK(model_); } ~PredictorTest() override {} @@ -121,7 +121,7 @@ TEST_F(PredictorTest, BatchTest) { int total_triggers = 0; string line; - std::ifstream fin(StrCat(TestDataPath(), "/", kSamples)); + std::ifstream fin(absl::StrCat(TestDataPath(), "/", kSamples)); while (std::getline(fin, line)) { const std::vector fields = absl::StrSplit(line, '\t'); if (fields.empty()) { diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index d27ab0c033aa0102a3acbdfc0a7b9d724791b19d..ab007993afc35a30179814df23d6c0175f1d955e 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -234,7 +234,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, next_id++; }; - auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); }; + auto add_add_params = [&add_scalar_int32](void* data) { + auto* builtin = reinterpret_cast(data); + add_scalar_int32(builtin->activation); + }; auto add_pooling_params = [&add_scalar_int32](void* data) { auto builtin = reinterpret_cast(data); @@ -309,7 +312,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, }; auto add_mean_params = [&add_scalar_int32](void* data) { - auto builtin = reinterpret_cast(data); + auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->keep_dims); }; @@ -345,11 +348,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, switch (builtin) { case tflite::BuiltinOperator_ADD: nn_op_type = ANEURALNETWORKS_ADD; - add_add_params(); + add_add_params(node.builtin_data); break; case tflite::BuiltinOperator_MUL: nn_op_type = ANEURALNETWORKS_MUL; - add_add_params(); + add_add_params(node.builtin_data); break; case tflite::BuiltinOperator_AVERAGE_POOL_2D: add_pooling_params(node.builtin_data); @@ -490,10 +493,17 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SELECT: case tflite::BuiltinOperator_SLICE: case tflite::BuiltinOperator_SIN: + case tflite::BuiltinOperator_LOG: case tflite::BuiltinOperator_TRANSPOSE_CONV: case tflite::BuiltinOperator_TILE: case tflite::BuiltinOperator_EXPAND_DIMS: case tflite::BuiltinOperator_SPARSE_TO_DENSE: + case tflite::BuiltinOperator_EQUAL: + case tflite::BuiltinOperator_NOT_EQUAL: + case tflite::BuiltinOperator_SUM: + case tflite::BuiltinOperator_SQRT: + case tflite::BuiltinOperator_RSQRT: + case tflite::BuiltinOperator_SHAPE: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc index dfdd80ea8a42af683632be1d7e8ab0062847077d..99c35b9cafd82c7dd7ffface33f9c6c59b404c58 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.cc +++ b/tensorflow/contrib/lite/optional_debug_tools.cc @@ -50,6 +50,8 @@ const char* TensorTypeName(TfLiteType type) { return "kTfLiteString"; case kTfLiteBool: return "kTfLiteBool"; + case kTfLiteInt16: + return "kTfLiteInt16"; } return "(invalid)"; } @@ -82,13 +84,13 @@ void PrintInterpreterState(Interpreter* interpreter) { for (int tensor_index = 0; tensor_index < interpreter->tensors_size(); tensor_index++) { TfLiteTensor* tensor = interpreter->tensor(tensor_index); - printf("Tensor %3d %10s %15s %10zu bytes (%4.1f MB) ", tensor_index, - TensorTypeName(tensor->type), AllocTypeName(tensor->allocation_type), - tensor->bytes, float(tensor->bytes) / float(1 << 20)); + printf("Tensor %3d %-20s %10s %15s %10zu bytes (%4.1f MB) ", tensor_index, + tensor->name, TensorTypeName(tensor->type), + AllocTypeName(tensor->allocation_type), tensor->bytes, + (static_cast(tensor->bytes) / (1 << 20))); PrintTfLiteIntVector(tensor->dims); - printf("\n"); } - + printf("\n"); for (int node_index = 0; node_index < interpreter->nodes_size(); node_index++) { const std::pair* node_and_reg = @@ -104,7 +106,4 @@ void PrintInterpreterState(Interpreter* interpreter) { } } -// Prints a dump of what tensors and what nodes are in the interpreter. -TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); - } // namespace tflite diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h index 1b6998cda382782b974bea3d18ffb6217e8f780c..7fb4b8d8b7ae87cc6e8dd8503c8a4ce0cef2ce8d 100644 --- a/tensorflow/contrib/lite/optional_debug_tools.h +++ b/tensorflow/contrib/lite/optional_debug_tools.h @@ -24,9 +24,6 @@ namespace tflite { // Prints a dump of what tensors and what nodes are in the interpreter. void PrintInterpreterState(Interpreter* interpreter); -// Prints a dump of what tensors and what nodes are in the interpreter. -TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter); - } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_ diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD index c31189f2b1f1ad6e3d8e2f5fcae9b6c2ef8eaf23..a162b87b8f98576ec7c3b2623d1d34f2baef6cce 100644 --- a/tensorflow/contrib/lite/profiling/BUILD +++ b/tensorflow/contrib/lite/profiling/BUILD @@ -2,9 +2,11 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + common_copts = [ "-Wall", -] +] + tflite_copts() cc_library( name = "profiler", @@ -36,12 +38,14 @@ cc_library( name = "time", srcs = ["time.cc"], hdrs = ["time.h"], + copts = common_copts, ) cc_library( name = "profile_summarizer", srcs = ["profile_summarizer.cc"], hdrs = ["profile_summarizer.h"], + copts = common_copts, deps = [ ":profiler", "//tensorflow/contrib/lite:framework", @@ -53,6 +57,7 @@ cc_library( cc_test( name = "profile_summarizer_test", srcs = ["profile_summarizer_test.cc"], + copts = common_copts, deps = [ ":profile_summarizer", "//tensorflow/contrib/lite:framework", diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc index 6f2c9cd2b39a1d6be77a10b18658665874067d87..45388b500c7897c8b33b49eb6ab4e9f8c4fdb37c 100644 --- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc @@ -85,11 +85,18 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, return details; } +tensorflow::StatSummarizerOptions GetProfileSummarizerOptions() { + auto options = tensorflow::StatSummarizerOptions(); + options.show_summary = true; + options.show_memory = false; + return options; +} + } // namespace ProfileSummarizer::ProfileSummarizer() - : stats_calculator_(new ::tensorflow::StatsCalculator( - tensorflow::StatSummarizerOptions())) {} + : stats_calculator_( + new ::tensorflow::StatsCalculator(GetProfileSummarizerOptions())) {} void ProfileSummarizer::ProcessProfiles( const std::vector& profile_stats, diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 7e6ff6c0a8314e71a64f27916a6189f229b81ab4..27909a9458f6b09f96cb556a5254f01e54f46e05 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -57,8 +57,9 @@ py_library( ":interpreter", ":lite_constants", ":op_hint", - "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python:graph_util", + "//tensorflow/python/saved_model:constants", + "//tensorflow/python/saved_model:loader", "//tensorflow/python/tools:freeze_graph_lib", ], ) diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index 0819475240d03a6948ce4bf798e932fa938f46a8..c038c88945b71f30bf091a1098dcf853f5415b1b 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -111,26 +111,27 @@ def tensor_name(x): return x.name.split(":")[0] -def toco_convert(input_data, - input_tensors, - output_tensors, - inference_type=lite_constants.FLOAT, - inference_input_type=None, - input_format=lite_constants.TENSORFLOW_GRAPHDEF, - output_format=lite_constants.TFLITE, - quantized_input_stats=None, - default_ranges_stats=None, - drop_control_dependency=True, - reorder_across_fake_quant=False, - allow_custom_ops=False, - change_concat_input_ranges=False): - """Convert a model using TOCO from `input_format` to `output_format`. +def build_toco_convert_protos(input_tensors, + output_tensors, + inference_type=lite_constants.FLOAT, + inference_input_type=None, + input_format=lite_constants.TENSORFLOW_GRAPHDEF, + output_format=lite_constants.TFLITE, + quantized_input_stats=None, + default_ranges_stats=None, + drop_control_dependency=True, + reorder_across_fake_quant=False, + allow_custom_ops=False, + change_concat_input_ranges=False, + quantize_weights=False, + dump_graphviz_dir=None, + dump_graphviz_video=False): + """Builds protocol buffers describing a conversion of a model using TOCO. Typically this is to convert from TensorFlow GraphDef to TFLite, in which case the default `input_format` and `output_format` are sufficient. Args: - input_data: Input data (i.e. often `sess.graph_def`). input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). @@ -143,10 +144,9 @@ def toco_convert(input_data, `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) output_format: Output file format. Currently must be `{TFLITE, GRAPHVIZ_DOT}`. (default TFLITE) - quantized_input_stats: Dict of strings representing input tensor names - mapped to tuple of integers representing the mean and standard deviation - of the training data (e.g., {"foo" : (0., 1.)}). Only need if - `inference_type` is `QUANTIZED_UINT8`. (default None) + quantized_input_stats: List of tuples of integers representing the mean and + standard deviation. Each tuple maps to the corresponding input tensor. + Only need if `inference_type` is `QUANTIZED_UINT8`. (default None) default_ranges_stats: Tuple of integers representing (min, max) range values for all arrays without a specified range. Intended for experimenting with quantization via "dummy quantization". (default None) @@ -158,18 +158,28 @@ def toco_convert(input_data, nodes is preventing graph transformations necessary to convert the graph. Results in a graph that differs from the quantized training graph, potentially causing differing arithmetic behavior. (default False) - change_concat_input_ranges: Boolean to change behavior of min/max ranges for - inputs and outputs of the concat operator for quantized models. Changes - the ranges of concat operator overlap when true. (default False) allow_custom_ops: Boolean indicating whether to allow custom operations. When false any unknown operation is an error. When true, custom ops are created for any op that is unknown. The developer will need to provide these to the TensorFlow Lite runtime with a custom resolver. (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) + quantize_weights: Boolean indicating whether to store weights as quantized + weights followed by dequantize operations. Computation is still done in + float, but reduces model size (at the cost of accuracy and latency). + (default False) + dump_graphviz_dir: Full filepath of folder to dump the graphs at various + stages of processing GraphViz .dot files. Preferred over + --output_format=GRAPHVIZ_DOT in order to keep the requirements of the + output file. (default None) + dump_graphviz_video: Boolean indicating whether to dump the graph after + every graph transformation. (default False) Returns: - The converted data. For example if TFLite was the destination, then - this will be a tflite flatbuffer in a bytes array. + model_flags, toco_flags: two protocol buffers describing the conversion + process. Raises: ValueError: If the input tensor type is unknown @@ -185,10 +195,13 @@ def toco_convert(input_data, toco.drop_control_dependency = drop_control_dependency toco.reorder_across_fake_quant = reorder_across_fake_quant toco.allow_custom_ops = allow_custom_ops + toco.quantize_weights = quantize_weights if default_ranges_stats: toco.default_ranges_min = default_ranges_stats[0] toco.default_ranges_max = default_ranges_stats[1] - + if dump_graphviz_dir: + toco.dump_graphviz_dir = dump_graphviz_dir + toco.dump_graphviz_include_video = dump_graphviz_video model = _model_flags_pb2.ModelFlags() model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): @@ -217,10 +230,35 @@ def toco_convert(input_data, for output_tensor in output_tensors: model.output_arrays.append(tensor_name(output_tensor)) + return model, toco + + +def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): + """"Convert a model using TOCO. - # TODO(aselle): Consider handling the case of allowing quantized - # inputs to be converted to float (via the toco.inference_input_type field). - data = toco_convert_protos(model.SerializeToString(), - toco.SerializeToString(), + Typically this function is used to convert from TensorFlow GraphDef to TFLite. + Conversion can be customized by providing arguments that are forwarded to + `build_toco_convert_protos` (see documentation for details). + + Args: + input_data: Input data (i.e. often `sess.graph_def`), + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + *args: See `build_toco_convert_protos`, + **kwargs: See `build_toco_convert_protos`. + + Returns: + The converted data. For example if TFLite was the destination, then + this will be a tflite flatbuffer in a bytes array. + + Raises: + Defined in `build_toco_convert_protos`. + """ + model_flags, toco_flags = build_toco_convert_protos(input_tensors, + output_tensors, + *args, **kwargs) + data = toco_convert_protos(model_flags.SerializeToString(), + toco_flags.SerializeToString(), input_data.SerializeToString()) return data diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index 5dad49f1ed29f3bd57b1b120808ef645adee760c..1553464b9fe30f596c151bcc67efe891bb913ba3 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -19,13 +19,12 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.lite.python.convert import tensor_name -from tensorflow.contrib.saved_model.python.saved_model import reader -from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader @@ -58,21 +57,8 @@ def _get_meta_graph_def(saved_model_dir, tag_set): Raises: ValueError: No valid MetaGraphDef for given tag_set. """ - saved_model = reader.read_saved_model(saved_model_dir) - tag_sets = [] - result_meta_graph_def = None - for meta_graph_def in saved_model.meta_graphs: - meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags) - tag_sets.append(meta_graph_tag_set) - if meta_graph_tag_set == tag_set: - result_meta_graph_def = meta_graph_def - logging.info("The given saved_model contains the following tags: %s", - tag_sets) - if result_meta_graph_def is not None: - return result_meta_graph_def - else: - raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible " - "values are '{}'. ".format(tag_set, tag_sets)) + with session.Session(graph=ops.Graph()) as sess: + return loader.load(sess, tag_set, saved_model_dir) def _get_signature_def(meta_graph, signature_key): @@ -97,9 +83,7 @@ def _get_signature_def(meta_graph, signature_key): raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible " "values are '{}'.".format(signature_key, ",".join(signature_def_keys))) - signature_def = signature_def_utils.get_signature_def_by_key( - meta_graph, signature_key) - return signature_def + return signature_def_map[signature_key] def _get_inputs_outputs(signature_def): @@ -247,6 +231,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, ValueError: SavedModel doesn't contain a MetaGraphDef identified by tag_set. signature_key is not in the MetaGraphDef. + assets/ directory is in the MetaGraphDef. input_shapes does not match the length of input_arrays. input_arrays or output_arrays are not valid. """ @@ -255,9 +240,13 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, signature_def = _get_signature_def(meta_graph, signature_key) inputs, outputs = _get_inputs_outputs(signature_def) + # Check SavedModel for assets directory. + collection_def = meta_graph.collection_def + if constants.ASSETS_KEY in collection_def: + raise ValueError("SavedModels with assets/ directory are not supported.") + graph = ops.Graph() with session.Session(graph=graph) as sess: - # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory. loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir) # Gets input and output tensors. diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index 779bda4c9d05fd056d6a262412fdcf0d47e7c57c..fd908234254185e0a0639618e936ca8ff58631da 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys from tensorflow.python.util.lazy_loader import LazyLoader # Lazy load since some of the performance benchmark skylark rules @@ -64,9 +65,38 @@ class Interpreter(object): raise ValueError('Can\'t both provide `model_path` and `model_content`') def allocate_tensors(self): + self._ensure_safe() if not self._interpreter.AllocateTensors(): raise ValueError('Failed to allocate tensors') + def _safe_to_run(self): + """Returns true if there exist no numpy array buffers. + + This means it is safe to run tflite calls that may destroy internally + allocated memory. This works, because in the wrapper.cc we have made + the numpy base be the self._interpreter. + """ + # NOTE, our tensor() call in cpp will use _interpreter as a base pointer. + # If this environment is the only _interpreter, then the ref count should be + # 2 (1 in self and 1 in temporary of sys.getrefcount). + return sys.getrefcount(self._interpreter) == 2 + + def _ensure_safe(self): + """Makes sure no numpy arrays pointing to internal buffers are active. + + This should be called from any function that will call a function on + _interpreter that may reallocate memory e.g. invoke(), ... + + Raises: + RuntimeError: If there exist numpy objects pointing to internal memory + then we throw. + """ + if not self._safe_to_run(): + raise RuntimeError("""There is at least 1 reference to internal data + in the interpreter in the form of a numpy array or slice. Be sure to + only hold the function returned from tensor() if you are using raw + data access.""") + def _get_tensor_details(self, tensor_index): """Gets tensor details. @@ -109,7 +139,10 @@ class Interpreter(object): ] def set_tensor(self, tensor_index, value): - """Sets the value of the input. + """Sets the value of the input tensor. Note this copies data in `value`. + + If you want to avoid copying, you can use the `tensor()` function to get a + numpy buffer pointing to the input buffer in the tflite interpreter. Args: tensor_index: Tensor index of tensor to set. This value can be gotten from @@ -133,6 +166,7 @@ class Interpreter(object): Raises: ValueError: If the interpreter could not resize the input tensor. """ + self._ensure_safe() if not self._interpreter.ResizeInputTensor(input_index, tensor_size): raise ValueError('Failed to resize input') @@ -147,7 +181,7 @@ class Interpreter(object): ] def get_tensor(self, tensor_index): - """Sets the value of the input. + """Gets the value of the input tensor. Note this makes a copy so prefer `tensor()`. Args: tensor_index: Tensor index of tensor to get. This value can be gotten from @@ -158,6 +192,60 @@ class Interpreter(object): """ return self._interpreter.GetTensor(tensor_index) + def tensor(self, tensor_index): + """Returns function that gives a numpy view of the current tensor buffer. + + This allows reading and writing to this tensors w/o copies. This more + closely mirrors the C++ Interpreter class interface's tensor() member, hence + the name. Be careful to not hold these output references through calls + to `allocate_tensors()` and `invoke()`. + + Usage: + + interpreter.allocate_tensors() + input = interpreter.tensor(interpreter.get_input_details()[0]["index"]) + output = interpreter.tensor(interpreter.get_output_details()[0]["index"]) + for i in range(10): + input().fill(3.) + interpreter.invoke() + print("inference %s" % output) + + Notice how this function avoids making a numpy array directly. This is + because it is important to not hold actual numpy views to the data longer + than necessary. If you do, then the interpreter can no longer be invoked, + because it is possible the interpreter would resize and invalidate the + referenced tensors. The NumPy API doesn't allow any mutability of the + the underlying buffers. + + WRONG: + + input = interpreter.tensor(interpreter.get_input_details()[0]["index"])() + output = interpreter.tensor(interpreter.get_output_details()[0]["index"])() + interpreter.allocate_tensors() # This will throw RuntimeError + for i in range(10): + input.fill(3.) + interpreter.invoke() # this will throw RuntimeError since input,output + + Args: + tensor_index: Tensor index of tensor to get. This value can be gotten from + the 'index' field in get_output_details. + + Returns: + A function that can return a new numpy array pointing to the internal + TFLite tensor state at any point. It is safe to hold the function forever, + but it is not safe to hold the numpy array forever. + """ + return lambda: self._interpreter.tensor(self._interpreter, tensor_index) + def invoke(self): + """Invoke the interpreter. + + Be sure to set the input sizes, allocate tensors and fill values before + calling this. + + Raises: + ValueError: When the underlying interpreter fails raise ValueError. + """ + self._ensure_safe() if not self._interpreter.Invoke(): raise ValueError('Failed to invoke TFLite model') diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py index f802edf020db8a9d4e7bb890aadaae7e34e983a8..5f1fa26c3b7f76309a6f1f80aa3c1e4889781528 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -91,5 +91,61 @@ class InterpreterTest(test_util.TensorFlowTestCase): self.assertTrue((expected_output == output_data).all()) +class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase): + + def setUp(self): + self.interpreter = interpreter_wrapper.Interpreter( + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite')) + self.interpreter.allocate_tensors() + self.input0 = self.interpreter.get_input_details()[0]['index'] + self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32) + + def testTensorAccessor(self): + """Check that tensor returns a reference.""" + array_ref = self.interpreter.tensor(self.input0) + np.copyto(array_ref(), self.initial_data) + self.assertAllEqual(array_ref(), self.initial_data) + self.assertAllEqual( + self.interpreter.get_tensor(self.input0), self.initial_data) + + def testGetTensorAccessor(self): + """Check that get_tensor returns a copy.""" + self.interpreter.set_tensor(self.input0, self.initial_data) + array_initial_copy = self.interpreter.get_tensor(self.input0) + new_value = np.add(1., array_initial_copy) + self.interpreter.set_tensor(self.input0, new_value) + self.assertAllEqual(array_initial_copy, self.initial_data) + self.assertAllEqual(self.interpreter.get_tensor(self.input0), new_value) + + def testBase(self): + self.assertTrue(self.interpreter._safe_to_run()) + _ = self.interpreter.tensor(self.input0) + self.assertTrue(self.interpreter._safe_to_run()) + in0 = self.interpreter.tensor(self.input0)() + self.assertFalse(self.interpreter._safe_to_run()) + in0b = self.interpreter.tensor(self.input0)() + self.assertFalse(self.interpreter._safe_to_run()) + # Now get rid of the buffers so that we can evaluate. + del in0 + del in0b + self.assertTrue(self.interpreter._safe_to_run()) + + def testBaseProtectsFunctions(self): + in0 = self.interpreter.tensor(self.input0)() + # Make sure we get an exception if we try to run an unsafe operation + with self.assertRaisesRegexp( + RuntimeError, 'There is at least 1 reference'): + _ = self.interpreter.allocate_tensors() + # Make sure we get an exception if we try to run an unsafe operation + with self.assertRaisesRegexp( + RuntimeError, 'There is at least 1 reference'): + _ = self.interpreter.invoke() + # Now test that we can run + del in0 # this is our only buffer reference, so now it is safe to change + in0safe = self.interpreter.tensor(self.input0) + _ = self.interpreter.allocate_tensors() + del in0safe # make sure in0Safe is held but lint doesn't complain + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD index 12ab38847dc0f838ae2c6bf80ed80805285e4b8b..634c2a1e1f5005208b4eea5c853a43cccf4d244c 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD @@ -14,7 +14,7 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/core:lib", - "//tensorflow/python:numpy_lib", + "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 5f304ad45d400b13e20bda8184b5b40cfe13f6c2..b283551c45d3d75aecb50043f1c7486b3345118d 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -21,7 +21,14 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/python/lib/core/numpy.h" + +// Disallow Numpy 1.7 deprecated symbols. +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include + +#include "numpy/arrayobject.h" +#include "numpy/ufuncobject.h" #if PY_MAJOR_VERSION >= 3 #define PY_TO_CPPSTRING PyBytes_AsStringAndSize @@ -35,6 +42,13 @@ namespace tflite { namespace interpreter_wrapper { namespace { + +// Calls PyArray's initialization to initialize all the API pointers. Note that +// this usage implies only this translation unit can use the pointers. See +// tensorflow/python/core/numpy.cc for a strategy if we ever need to extend +// this further. +void ImportNumpy() { import_array1(); } + std::unique_ptr CreateInterpreter( const tflite::FlatBufferModel* model, const tflite::ops::builtin::BuiltinOpResolver& resolver) { @@ -42,7 +56,7 @@ std::unique_ptr CreateInterpreter( return nullptr; } - tensorflow::ImportNumpy(); + ImportNumpy(); std::unique_ptr interpreter; tflite::InterpreterBuilder(*model, resolver)(&interpreter); @@ -68,6 +82,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { return NPY_FLOAT32; case kTfLiteInt32: return NPY_INT32; + case kTfLiteInt16: + return NPY_INT16; case kTfLiteUInt8: return NPY_UINT8; case kTfLiteInt64: @@ -90,6 +106,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { return kTfLiteFloat32; case NPY_INT32: return kTfLiteInt32; + case NPY_INT16: + return kTfLiteInt16; case NPY_UINT8: return kTfLiteUInt8; case NPY_INT64: @@ -284,47 +302,93 @@ bool InterpreterWrapper::SetTensor(int i, PyObject* value) { return true; } -PyObject* InterpreterWrapper::GetTensor(int i) const { - if (!interpreter_) { +namespace { + +PyObject* CheckGetTensorArgs(Interpreter* interpreter, int tensor_index, + TfLiteTensor** tensor, int* type_num) { + if (!interpreter) { LOG(ERROR) << "Invalid interpreter."; Py_INCREF(Py_None); return Py_None; } - if (i >= interpreter_->tensors_size()) { - LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index " - << interpreter_->inputs().size(); + if (tensor_index >= interpreter->tensors_size() || tensor_index < 0) { + LOG(ERROR) << "Invalid tensor index: " << tensor_index + << " exceeds max tensor index " << interpreter->inputs().size(); Py_INCREF(Py_None); return Py_None; } - const TfLiteTensor* output_tensor = interpreter_->tensor(i); - const int tensor_size = output_tensor->bytes; - if (tensor_size <= 0) { + *tensor = interpreter->tensor(tensor_index); + if ((*tensor)->bytes == 0) { LOG(ERROR) << "Invalid tensor size"; Py_INCREF(Py_None); return Py_None; } - int type_num = TfLiteTypeToPyArrayType(output_tensor->type); - if (type_num == -1) { - LOG(ERROR) << "Unknown tensor type " << output_tensor->type; + *type_num = TfLiteTypeToPyArrayType((*tensor)->type); + if (*type_num == -1) { + LOG(ERROR) << "Unknown tensor type " << (*tensor)->type; + Py_INCREF(Py_None); + return Py_None; + } + + if (!(*tensor)->data.raw) { + LOG(ERROR) << "Tensor data is null."; Py_INCREF(Py_None); return Py_None; } - void* data = malloc(tensor_size); - memcpy(data, output_tensor->data.raw, tensor_size); + return nullptr; +} + +} // namespace - const TfLiteIntArray* output_dims = output_tensor->dims; - std::vector dims(output_dims->data, - output_dims->data + output_dims->size); +PyObject* InterpreterWrapper::GetTensor(int i) const { + // Sanity check accessor + TfLiteTensor* tensor = nullptr; + int type_num = 0; + if (PyObject* pynone_or_nullptr = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { + return pynone_or_nullptr; + } + std::vector dims(tensor->dims->data, + tensor->dims->data + tensor->dims->size); + // Make a buffer copy but we must tell Numpy It owns that data or else + // it will leak. + void* data = malloc(tensor->bytes); + if (!data) { + LOG(ERROR) << "Malloc to copy tensor failed."; + Py_INCREF(Py_None); + return Py_None; + } + memcpy(data, tensor->data.raw, tensor->bytes); PyObject* np_array = PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data); - + PyArray_ENABLEFLAGS(reinterpret_cast(np_array), + NPY_ARRAY_OWNDATA); return PyArray_Return(reinterpret_cast(np_array)); } +PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) { + // Sanity check accessor + TfLiteTensor* tensor = nullptr; + int type_num = 0; + if (PyObject* pynone_or_nullptr = + CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) { + return pynone_or_nullptr; + } + + std::vector dims(tensor->dims->data, + tensor->dims->data + tensor->dims->size); + PyArrayObject* np_array = + reinterpret_cast(PyArray_SimpleNewFromData( + dims.size(), dims.data(), type_num, tensor->data.raw)); + Py_INCREF(base_object); // SetBaseObject steals, so we need to add. + PyArray_SetBaseObject(np_array, base_object); + return PyArray_Return(np_array); +} + InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( const char* model_path) { std::unique_ptr model = diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index 01320af7a9ea3a652020e2b42300da6081ff68e5..e7343cb388d657e472464f69fa8cd0c6ddc60923 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +// Place `` before to avoid build failures in macOS. +#include #include // We forward declare TFLite classes here to avoid exposing them to SWIG. @@ -56,6 +58,9 @@ class InterpreterWrapper { PyObject* TensorQuantization(int i) const; bool SetTensor(int i, PyObject* value); PyObject* GetTensor(int i) const; + // Returns a reference to tensor index i as a numpy array. The base_object + // should be the interpreter object providing the memory. + PyObject* tensor(PyObject* base_object, int i); private: InterpreterWrapper(std::unique_ptr model); diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 72605643185301e53b9acce5b8352d36f23fda18..88dda7290b10b653cf0b3e01a97729e82d1e5259 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -22,6 +22,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. @@Interpreter @@OpHint @@convert_op_hints_to_stubs +@@build_toco_convert_protos @@FLOAT @@QUANTIZED_UINT8 @@ -38,6 +39,7 @@ from six import PY3 from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError from tensorflow.contrib.lite.python import lite_constants as constants +from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.contrib.lite.python.convert import toco_convert from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import @@ -54,6 +56,7 @@ from tensorflow.python.framework.importer import import_graph_def from tensorflow.python.ops.variables import global_variables_initializer from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import tag_constants +# from tensorflow.python.util.all_util import remove_undocumented class TocoConverter(object): @@ -94,6 +97,16 @@ class TocoConverter(object): created for any op that is unknown. The developer will need to provide these to the TensorFlow Lite runtime with a custom resolver. (default False) + quantize_weights: Boolean indicating whether to store weights as quantized + weights followed by dequantize operations. Computation is still done in + float, but reduces model size (at the cost of accuracy and latency). + (default False) + dump_graphviz_dir: Full filepath of folder to dump the graphs at various + stages of processing GraphViz .dot files. Preferred over + --output_format=GRAPHVIZ_DOT in order to keep the requirements of the + output file. (default None) + dump_graphviz_video: Boolean indicating whether to dump the graph after + every graph transformation. (default False) Example usage: @@ -135,6 +148,9 @@ class TocoConverter(object): self.reorder_across_fake_quant = False self.change_concat_input_ranges = False self.allow_custom_ops = False + self.quantize_weights = False + self.dump_graphviz_dir = None + self.dump_graphviz_video = False @classmethod def from_session(cls, sess, input_tensors, output_tensors): @@ -210,7 +226,7 @@ class TocoConverter(object): # Check if graph is frozen. if not _is_frozen_graph(sess): - raise ValueError("Please freeze the graph using freeze_graph.py") + raise ValueError("Please freeze the graph using freeze_graph.py.") # Create TocoConverter class. return cls(sess.graph_def, input_tensors, output_tensors) @@ -310,9 +326,20 @@ class TocoConverter(object): drop_control_dependency=self.drop_control_dependency, reorder_across_fake_quant=self.reorder_across_fake_quant, change_concat_input_ranges=self.change_concat_input_ranges, - allow_custom_ops=self.allow_custom_ops) + allow_custom_ops=self.allow_custom_ops, + quantize_weights=self.quantize_weights, + dump_graphviz_dir=self.dump_graphviz_dir, + dump_graphviz_video=self.dump_graphviz_video) return result + def get_input_arrays(self): + """Returns a list of the names of the input tensors. + + Returns: + List of strings. + """ + return [tensor_name(tensor) for tensor in self._input_tensors] + def _set_batch_size(self, batch_size): """Sets the first dimension of the input tensor to `batch_size`. @@ -364,3 +391,5 @@ def _freeze_graph(sess, output_tensors): output_arrays) else: return sess.graph_def + +# remove_undocumented(__name__) diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 019a3a5f69b35f9f69f835471a45c58f3d82bb3b..a9475de47408d7d451663cdf021d40eaa85c7c63 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -25,9 +25,11 @@ from tensorflow.contrib.lite.python import lite from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.python.interpreter import Interpreter from tensorflow.python.client import session +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -218,6 +220,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) + # TODO(nupurgarg): Verify value of contents in GraphViz. def testGraphviz(self): in_tensor = array_ops.placeholder( shape=[1, 16, 16, 3], dtype=dtypes.float32) @@ -230,8 +233,42 @@ class FromSessionTest(test_util.TensorFlowTestCase): graphviz_output = converter.convert() self.assertTrue(graphviz_output) + # TODO(nupurgarg): Verify value of contents in GraphViz. + def testDumpGraphviz(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure interpreter is able to allocate and check graphviz data. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + num_items_graphviz = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + converter.dump_graphviz_video = True + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure graphviz folder has more data after using video flag. + num_items_graphviz_video = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz_video > num_items_graphviz) + def testInferenceInputType(self): - in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) out_tensor = in_tensor + in_tensor sess = session.Session() @@ -250,14 +287,13 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertEqual('Placeholder', input_details[0]['name']) self.assertEqual(np.uint8, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) - self.assertEqual((0., 0.), input_details[0]['quantization']) + self.assertEqual((1., 0.), input_details[0]['quantization']) output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('add', output_details[0]['name']) - self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertEqual(np.float32, output_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) - self.assertEqual((0., 0.), input_details[0]['quantization']) def testDefaultRangesStats(self): in_tensor = array_ops.placeholder( @@ -291,6 +327,36 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + def testQuantizeWeights(self): + np.random.seed(0) + # We need the tensor to have more than 1024 elements for quantize_weights + # to kick in. Thus, the [33, 33] shape. + in_tensor_1 = array_ops.placeholder( + shape=[33, 33], dtype=dtypes.float32, name='inputA') + in_tensor_2 = constant_op.constant( + np.random.uniform(low=-10., high=10., size=(33, 33)), + shape=[33, 33], + dtype=dtypes.float32, + name='inputB') + out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') + sess = session.Session() + + # Convert float model. + float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1], + [out_tensor]) + float_tflite = float_converter.convert() + self.assertTrue(float_tflite) + + # Convert quantized weights model. + quantized_weights_converter = lite.TocoConverter.from_session( + sess, [in_tensor_1], [out_tensor]) + quantized_weights_converter.quantize_weights = True + quantized_weights_tflite = quantized_weights_converter.convert() + self.assertTrue(quantized_weights_tflite) + + # Ensure that the quantized weights tflite model is smaller. + self.assertTrue(len(quantized_weights_tflite) < len(float_tflite)) + class FromFrozenGraphFile(test_util.TensorFlowTestCase): @@ -369,7 +435,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as error: lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], ['add']) - self.assertEqual('Please freeze the graph using freeze_graph.py', + self.assertEqual('Please freeze the graph using freeze_graph.py.', str(error.exception)) def testPbtxt(self): diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index 337f05785eed6c452a43dc1b2118389ff8106714..f497533bed054d260aefc7b3fe67ae655c7cbcda 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -86,6 +86,9 @@ def _convert_model(flags): Args: flags: argparse.Namespace object. + + Raises: + ValueError: Invalid flags. """ # Create converter. converter = _get_toco_converter(flags) @@ -99,12 +102,22 @@ def _convert_model(flags): flags.output_format) if flags.mean_values and flags.std_dev_values: - input_arrays = _parse_array(flags.input_arrays) + input_arrays = converter.get_input_arrays() std_dev_values = _parse_int_array(flags.std_dev_values) mean_values = _parse_int_array(flags.mean_values) quant_stats = zip(mean_values, std_dev_values) + if ((not flags.input_arrays and len(input_arrays) > 1) or + (len(input_arrays) != len(quant_stats))): + raise ValueError("Mismatching --input_arrays, --std_dev_values, and " + "--mean_values. The flags must have the same number of " + "items. The current input arrays are '{0}'. " + "--input_arrays must be present when specifying " + "--std_dev_values and --mean_values with multiple input " + "tensors in order to map between names and " + "values.".format(",".join(input_arrays))) converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) - if flags.default_ranges_min and flags.default_ranges_max: + if (flags.default_ranges_min is not None) and (flags.default_ranges_max is + not None): converter.default_ranges_stats = (flags.default_ranges_min, flags.default_ranges_max) @@ -116,6 +129,12 @@ def _convert_model(flags): converter.change_concat_input_ranges = flags.change_concat_input_ranges if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops + if flags.quantize_weights: + converter.quantize_weights = flags.quantize_weights + if flags.dump_graphviz_dir: + converter.dump_graphviz_dir = flags.dump_graphviz_dir + if flags.dump_graphviz_video: + converter.dump_graphviz_vode = flags.dump_graphviz_video # Convert model. output_data = converter.convert() @@ -147,9 +166,14 @@ def _check_flags(flags, unparsed): output = "" for flag in unparsed: output += _get_message_unparsed(flag, "--input_file", "--graph_def_file") + output += _get_message_unparsed(flag, "--savedmodel_directory", + "--saved_model_dir") output += _get_message_unparsed(flag, "--std_value", "--std_dev_values") output += _get_message_unparsed(flag, "--batch_size", "--input_shapes") - raise ValueError(output) + output += _get_message_unparsed(flag, "--dump_graphviz", + "--dump_graphviz_dir") + if output: + raise ValueError(output) # Check that flags are valid. if flags.graph_def_file and (not flags.input_arrays or @@ -168,15 +192,11 @@ def _check_flags(flags, unparsed): if bool(flags.std_dev_values) != bool(flags.mean_values): raise ValueError("--std_dev_values and --mean_values must be used " "together") - if not flags.input_arrays: - raise ValueError("--std_dev_values and --mean_values must be used with " - "--input_arrays") - if (flags.std_dev_values.count(",") != flags.mean_values.count(",") or - flags.std_dev_values.count(",") != flags.input_arrays.count(",")): - raise ValueError("--std_dev_values, --mean_values, and --input_arrays " - "must have the same number of items") - - if bool(flags.default_ranges_min) != bool(flags.default_ranges_max): + if flags.std_dev_values.count(",") != flags.mean_values.count(","): + raise ValueError("--std_dev_values, --mean_values must have the same " + "number of items") + + if (flags.default_ranges_min is None) != (flags.default_ranges_max is None): raise ValueError("--default_ranges_min and --default_ranges_max must be " "used together") @@ -208,17 +228,17 @@ def run_main(_): # Model format flags. parser.add_argument( "--output_format", - type=str, + type=str.upper, choices=["TFLITE", "GRAPHVIZ_DOT"], help="Output file format.") parser.add_argument( "--inference_type", - type=str, + type=str.upper, choices=["FLOAT", "QUANTIZED_UINT8"], help="Target data type of arrays in the output file.") parser.add_argument( "--inference_input_type", - type=str, + type=str.upper, choices=["FLOAT", "QUANTIZED_UINT8"], help=("Target data type of input arrays. Allows for a different type for " "input arrays in the case of quantization.")) @@ -273,17 +293,23 @@ def run_main(_): help=("Default value for max bound of min/max range values used for all " "arrays without a specified range, Intended for experimenting with " "quantization via \"dummy quantization\". (default None)")) + parser.add_argument( + "--quantize_weights", + type=bool, + help=("Store float weights as quantized weights followed by dequantize " + "operations. Inference is still done in FLOAT, but reduces model " + "size (at the cost of accuracy and latency).")) # Graph manipulation flags. parser.add_argument( "--drop_control_dependency", - type=bool, + action="store_true", help=("Boolean indicating whether to drop control dependencies silently. " "This is due to TensorFlow not supporting control dependencies. " "(default True)")) parser.add_argument( "--reorder_across_fake_quant", - type=bool, + action="store_true", help=("Boolean indicating whether to reorder FakeQuant nodes in " "unexpected locations. Used when the location of the FakeQuant " "nodes is preventing graph transformations necessary to convert " @@ -292,19 +318,33 @@ def run_main(_): "behavior. (default False)")) parser.add_argument( "--change_concat_input_ranges", - type=bool, + action="store_true", help=("Boolean to change behavior of min/max ranges for inputs and " "outputs of the concat operator for quantized models. Changes the " "ranges of concat operator overlap when true. (default False)")) parser.add_argument( "--allow_custom_ops", - type=bool, + action="store_true", help=("Boolean indicating whether to allow custom operations. When false " "any unknown operation is an error. When true, custom ops are " "created for any op that is unknown. The developer will need to " "provide these to the TensorFlow Lite runtime with a custom " "resolver. (default False)")) + # Logging flags. + parser.add_argument( + "--dump_graphviz_dir", + type=str, + help=("Full filepath of folder to dump the graphs at various stages of " + "processing GraphViz .dot files. Preferred over --output_format=" + "GRAPHVIZ_DOT in order to keep the requirements of the output " + "file.")) + parser.add_argument( + "--dump_graphviz_video", + action="store_true", + help=("Boolean indicating whether to dump the graph after every graph " + "transformation")) + tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:]) try: _check_flags(tflite_flags, unparsed) diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc index 64ab0a9fe2f01a732af91ed4052e44cf8c38f89b..9dc8daa227dd68ccde2efa4013ac4465a72e6bb0 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc @@ -39,7 +39,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. #ifdef __cplusplus extern "C" { diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 7dbb36c864a84b7ba848bfb8a48e09023726f4a3..df43f1e5abf921410b14912e33562bf1a7067795 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -34,6 +34,7 @@ enum TensorType : byte { INT64 = 4, STRING = 5, BOOL = 6, + INT16 = 7, } // Parameters for converting a quantized tensor back to float. Given a @@ -63,6 +64,8 @@ table Tensor { buffer:uint; name:string; // For debugging and importing back into tensorflow. quantization:QuantizationParameters; // Optional. + + is_variable:bool = false; } // A list of builtin operators. Builtin operators are slightly faster than custom @@ -148,6 +151,13 @@ enum BuiltinOperator : byte { SPARSE_TO_DENSE = 68, TILE = 69, EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, + SUM=74, + SQRT = 75, + RSQRT = 76, + SHAPE = 77, } // Options for the builtin operators. @@ -178,7 +188,7 @@ union BuiltinOptions { BatchToSpaceNDOptions, SpaceToBatchNDOptions, TransposeOptions, - MeanOptions, + ReducerOptions, SubOptions, DivOptions, SqueezeOptions, @@ -204,6 +214,9 @@ union BuiltinOptions { SparseToDenseOptions, TileOptions, ExpandDimsOptions, + EqualOptions, + NotEqualOptions, + ShapeOptions, } enum Padding : byte { SAME, VALID } @@ -403,7 +416,7 @@ table TransposeOptions { table ExpOptions { } -table MeanOptions { +table ReducerOptions { keep_dims: bool; } @@ -478,6 +491,17 @@ table SparseToDenseOptions { validate_indices:bool; } +table EqualOptions { +} + +table NotEqualOptions { +} + +table ShapeOptions { + // Optional output type of the operation (int32 or int64). Defaults to int32. + out_type : TensorType; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { @@ -509,6 +533,16 @@ table Operator { builtin_options:BuiltinOptions; custom_options:[ubyte]; custom_options_format:CustomOptionsFormat; + + // A list of booleans indicating the input tensors which are being mutated by + // this operator.(e.g. used by RNN and LSTM). + // For example, if the "inputs" array refers to 5 tensors and the second and + // fifth are mutable variables, then this list will contain + // [false, true, false, false, true]. + // + // If the list is empty, no variable is mutated in this operator. + // The list either has the same length as `inputs`, or is empty. + mutating_variable_inputs:[bool]; } // The root type, defining a subgraph, which typically represents an entire diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index b1beb39b28e79ea8ea3fafc31e0d6dfc91c0f6ce..8c0660dfe201ac7ad0f45b6fd234c213a06416b6 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -127,8 +127,8 @@ struct TransposeOptionsT; struct ExpOptions; struct ExpOptionsT; -struct MeanOptions; -struct MeanOptionsT; +struct ReducerOptions; +struct ReducerOptionsT; struct SqueezeOptions; struct SqueezeOptionsT; @@ -187,6 +187,15 @@ struct ExpandDimsOptionsT; struct SparseToDenseOptions; struct SparseToDenseOptionsT; +struct EqualOptions; +struct EqualOptionsT; + +struct NotEqualOptions; +struct NotEqualOptionsT; + +struct ShapeOptions; +struct ShapeOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -210,11 +219,12 @@ enum TensorType { TensorType_INT64 = 4, TensorType_STRING = 5, TensorType_BOOL = 6, + TensorType_INT16 = 7, TensorType_MIN = TensorType_FLOAT32, - TensorType_MAX = TensorType_BOOL + TensorType_MAX = TensorType_INT16 }; -inline TensorType (&EnumValuesTensorType())[7] { +inline TensorType (&EnumValuesTensorType())[8] { static TensorType values[] = { TensorType_FLOAT32, TensorType_FLOAT16, @@ -222,7 +232,8 @@ inline TensorType (&EnumValuesTensorType())[7] { TensorType_UINT8, TensorType_INT64, TensorType_STRING, - TensorType_BOOL + TensorType_BOOL, + TensorType_INT16 }; return values; } @@ -236,6 +247,7 @@ inline const char **EnumNamesTensorType() { "INT64", "STRING", "BOOL", + "INT16", nullptr }; return names; @@ -317,11 +329,18 @@ enum BuiltinOperator { BuiltinOperator_SPARSE_TO_DENSE = 68, BuiltinOperator_TILE = 69, BuiltinOperator_EXPAND_DIMS = 70, + BuiltinOperator_EQUAL = 71, + BuiltinOperator_NOT_EQUAL = 72, + BuiltinOperator_LOG = 73, + BuiltinOperator_SUM = 74, + BuiltinOperator_SQRT = 75, + BuiltinOperator_RSQRT = 76, + BuiltinOperator_SHAPE = 77, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_EXPAND_DIMS + BuiltinOperator_MAX = BuiltinOperator_SHAPE }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[70] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[77] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -392,7 +411,14 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[70] { BuiltinOperator_TRANSPOSE_CONV, BuiltinOperator_SPARSE_TO_DENSE, BuiltinOperator_TILE, - BuiltinOperator_EXPAND_DIMS + BuiltinOperator_EXPAND_DIMS, + BuiltinOperator_EQUAL, + BuiltinOperator_NOT_EQUAL, + BuiltinOperator_LOG, + BuiltinOperator_SUM, + BuiltinOperator_SQRT, + BuiltinOperator_RSQRT, + BuiltinOperator_SHAPE }; return values; } @@ -470,6 +496,13 @@ inline const char **EnumNamesBuiltinOperator() { "SPARSE_TO_DENSE", "TILE", "EXPAND_DIMS", + "EQUAL", + "NOT_EQUAL", + "LOG", + "SUM", + "SQRT", + "RSQRT", + "SHAPE", nullptr }; return names; @@ -508,7 +541,7 @@ enum BuiltinOptions { BuiltinOptions_BatchToSpaceNDOptions = 24, BuiltinOptions_SpaceToBatchNDOptions = 25, BuiltinOptions_TransposeOptions = 26, - BuiltinOptions_MeanOptions = 27, + BuiltinOptions_ReducerOptions = 27, BuiltinOptions_SubOptions = 28, BuiltinOptions_DivOptions = 29, BuiltinOptions_SqueezeOptions = 30, @@ -534,11 +567,14 @@ enum BuiltinOptions { BuiltinOptions_SparseToDenseOptions = 50, BuiltinOptions_TileOptions = 51, BuiltinOptions_ExpandDimsOptions = 52, + BuiltinOptions_EqualOptions = 53, + BuiltinOptions_NotEqualOptions = 54, + BuiltinOptions_ShapeOptions = 55, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_ExpandDimsOptions + BuiltinOptions_MAX = BuiltinOptions_ShapeOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[53] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[56] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -567,7 +603,7 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[53] { BuiltinOptions_BatchToSpaceNDOptions, BuiltinOptions_SpaceToBatchNDOptions, BuiltinOptions_TransposeOptions, - BuiltinOptions_MeanOptions, + BuiltinOptions_ReducerOptions, BuiltinOptions_SubOptions, BuiltinOptions_DivOptions, BuiltinOptions_SqueezeOptions, @@ -592,7 +628,10 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[53] { BuiltinOptions_TransposeConvOptions, BuiltinOptions_SparseToDenseOptions, BuiltinOptions_TileOptions, - BuiltinOptions_ExpandDimsOptions + BuiltinOptions_ExpandDimsOptions, + BuiltinOptions_EqualOptions, + BuiltinOptions_NotEqualOptions, + BuiltinOptions_ShapeOptions }; return values; } @@ -626,7 +665,7 @@ inline const char **EnumNamesBuiltinOptions() { "BatchToSpaceNDOptions", "SpaceToBatchNDOptions", "TransposeOptions", - "MeanOptions", + "ReducerOptions", "SubOptions", "DivOptions", "SqueezeOptions", @@ -652,6 +691,9 @@ inline const char **EnumNamesBuiltinOptions() { "SparseToDenseOptions", "TileOptions", "ExpandDimsOptions", + "EqualOptions", + "NotEqualOptions", + "ShapeOptions", nullptr }; return names; @@ -770,8 +812,8 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions; }; -template<> struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = BuiltinOptions_MeanOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReducerOptions; }; template<> struct BuiltinOptionsTraits { @@ -874,6 +916,18 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1113,13 +1167,13 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_TransposeOptions ? reinterpret_cast(value) : nullptr; } - MeanOptionsT *AsMeanOptions() { - return type == BuiltinOptions_MeanOptions ? - reinterpret_cast(value) : nullptr; + ReducerOptionsT *AsReducerOptions() { + return type == BuiltinOptions_ReducerOptions ? + reinterpret_cast(value) : nullptr; } - const MeanOptionsT *AsMeanOptions() const { - return type == BuiltinOptions_MeanOptions ? - reinterpret_cast(value) : nullptr; + const ReducerOptionsT *AsReducerOptions() const { + return type == BuiltinOptions_ReducerOptions ? + reinterpret_cast(value) : nullptr; } SubOptionsT *AsSubOptions() { return type == BuiltinOptions_SubOptions ? @@ -1321,6 +1375,30 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_ExpandDimsOptions ? reinterpret_cast(value) : nullptr; } + EqualOptionsT *AsEqualOptions() { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast(value) : nullptr; + } + const EqualOptionsT *AsEqualOptions() const { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast(value) : nullptr; + } + NotEqualOptionsT *AsNotEqualOptions() { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast(value) : nullptr; + } + const NotEqualOptionsT *AsNotEqualOptions() const { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast(value) : nullptr; + } + ShapeOptionsT *AsShapeOptions() { + return type == BuiltinOptions_ShapeOptions ? + reinterpret_cast(value) : nullptr; + } + const ShapeOptionsT *AsShapeOptions() const { + return type == BuiltinOptions_ShapeOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -1626,9 +1704,11 @@ struct TensorT : public flatbuffers::NativeTable { uint32_t buffer; std::string name; std::unique_ptr quantization; + bool is_variable; TensorT() : type(TensorType_FLOAT32), - buffer(0) { + buffer(0), + is_variable(false) { } }; @@ -1639,7 +1719,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_TYPE = 6, VT_BUFFER = 8, VT_NAME = 10, - VT_QUANTIZATION = 12 + VT_QUANTIZATION = 12, + VT_IS_VARIABLE = 14 }; const flatbuffers::Vector *shape() const { return GetPointer *>(VT_SHAPE); @@ -1656,6 +1737,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const QuantizationParameters *quantization() const { return GetPointer(VT_QUANTIZATION); } + bool is_variable() const { + return GetField(VT_IS_VARIABLE, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && @@ -1666,6 +1750,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.Verify(name()) && VerifyOffset(verifier, VT_QUANTIZATION) && verifier.VerifyTable(quantization()) && + VerifyField(verifier, VT_IS_VARIABLE) && verifier.EndTable(); } TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -1691,6 +1776,9 @@ struct TensorBuilder { void add_quantization(flatbuffers::Offset quantization) { fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); } + void add_is_variable(bool is_variable) { + fbb_.AddElement(Tensor::VT_IS_VARIABLE, static_cast(is_variable), 0); + } explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -1709,12 +1797,14 @@ inline flatbuffers::Offset CreateTensor( TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, flatbuffers::Offset name = 0, - flatbuffers::Offset quantization = 0) { + flatbuffers::Offset quantization = 0, + bool is_variable = false) { TensorBuilder builder_(_fbb); builder_.add_quantization(quantization); builder_.add_name(name); builder_.add_buffer(buffer); builder_.add_shape(shape); + builder_.add_is_variable(is_variable); builder_.add_type(type); return builder_.Finish(); } @@ -1725,14 +1815,16 @@ inline flatbuffers::Offset CreateTensorDirect( TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, const char *name = nullptr, - flatbuffers::Offset quantization = 0) { + flatbuffers::Offset quantization = 0, + bool is_variable = false) { return tflite::CreateTensor( _fbb, shape ? _fbb.CreateVector(*shape) : 0, type, buffer, name ? _fbb.CreateString(name) : 0, - quantization); + quantization, + is_variable); } flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -3777,16 +3869,16 @@ inline flatbuffers::Offset CreateExpOptions( flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -struct MeanOptionsT : public flatbuffers::NativeTable { - typedef MeanOptions TableType; +struct ReducerOptionsT : public flatbuffers::NativeTable { + typedef ReducerOptions TableType; bool keep_dims; - MeanOptionsT() + ReducerOptionsT() : keep_dims(false) { } }; -struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { - typedef MeanOptionsT NativeTableType; +struct ReducerOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ReducerOptionsT NativeTableType; enum { VT_KEEP_DIMS = 4 }; @@ -3798,38 +3890,38 @@ struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_KEEP_DIMS) && verifier.EndTable(); } - MeanOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ReducerOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReducerOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -struct MeanOptionsBuilder { +struct ReducerOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_keep_dims(bool keep_dims) { - fbb_.AddElement(MeanOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); + fbb_.AddElement(ReducerOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); } - explicit MeanOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + explicit ReducerOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); } - MeanOptionsBuilder &operator=(const MeanOptionsBuilder &); - flatbuffers::Offset Finish() { + ReducerOptionsBuilder &operator=(const ReducerOptionsBuilder &); + flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); - auto o = flatbuffers::Offset(end); + auto o = flatbuffers::Offset(end); return o; } }; -inline flatbuffers::Offset CreateMeanOptions( +inline flatbuffers::Offset CreateReducerOptions( flatbuffers::FlatBufferBuilder &_fbb, bool keep_dims = false) { - MeanOptionsBuilder builder_(_fbb); + ReducerOptionsBuilder builder_(_fbb); builder_.add_keep_dims(keep_dims); return builder_.Finish(); } -flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateReducerOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SqueezeOptionsT : public flatbuffers::NativeTable { typedef SqueezeOptions TableType; @@ -4781,6 +4873,140 @@ inline flatbuffers::Offset CreateSparseToDenseOptions( flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct EqualOptionsT : public flatbuffers::NativeTable { + typedef EqualOptions TableType; + EqualOptionsT() { + } +}; + +struct EqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef EqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + EqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct EqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit EqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + EqualOptionsBuilder &operator=(const EqualOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + EqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct NotEqualOptionsT : public flatbuffers::NativeTable { + typedef NotEqualOptions TableType; + NotEqualOptionsT() { + } +}; + +struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef NotEqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NotEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NotEqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit NotEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + NotEqualOptionsBuilder &operator=(const NotEqualOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateNotEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + NotEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ShapeOptionsT : public flatbuffers::NativeTable { + typedef ShapeOptions TableType; + TensorType out_type; + ShapeOptionsT() + : out_type(TensorType_FLOAT32) { + } +}; + +struct ShapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ShapeOptionsT NativeTableType; + enum { + VT_OUT_TYPE = 4 + }; + TensorType out_type() const { + return static_cast(GetField(VT_OUT_TYPE, 0)); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OUT_TYPE) && + verifier.EndTable(); + } + ShapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ShapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ShapeOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_out_type(TensorType out_type) { + fbb_.AddElement(ShapeOptions::VT_OUT_TYPE, static_cast(out_type), 0); + } + explicit ShapeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ShapeOptionsBuilder &operator=(const ShapeOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateShapeOptions( + flatbuffers::FlatBufferBuilder &_fbb, + TensorType out_type = TensorType_FLOAT32) { + ShapeOptionsBuilder builder_(_fbb); + builder_.add_out_type(out_type); + return builder_.Finish(); +} + +flatbuffers::Offset CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -4879,6 +5105,7 @@ struct OperatorT : public flatbuffers::NativeTable { BuiltinOptionsUnion builtin_options; std::vector custom_options; CustomOptionsFormat custom_options_format; + std::vector mutating_variable_inputs; OperatorT() : opcode_index(0), custom_options_format(CustomOptionsFormat_FLEXBUFFERS) { @@ -4894,7 +5121,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_BUILTIN_OPTIONS_TYPE = 10, VT_BUILTIN_OPTIONS = 12, VT_CUSTOM_OPTIONS = 14, - VT_CUSTOM_OPTIONS_FORMAT = 16 + VT_CUSTOM_OPTIONS_FORMAT = 16, + VT_MUTATING_VARIABLE_INPUTS = 18 }; uint32_t opcode_index() const { return GetField(VT_OPCODE_INDEX, 0); @@ -4990,8 +5218,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const TransposeOptions *builtin_options_as_TransposeOptions() const { return builtin_options_type() == BuiltinOptions_TransposeOptions ? static_cast(builtin_options()) : nullptr; } - const MeanOptions *builtin_options_as_MeanOptions() const { - return builtin_options_type() == BuiltinOptions_MeanOptions ? static_cast(builtin_options()) : nullptr; + const ReducerOptions *builtin_options_as_ReducerOptions() const { + return builtin_options_type() == BuiltinOptions_ReducerOptions ? static_cast(builtin_options()) : nullptr; } const SubOptions *builtin_options_as_SubOptions() const { return builtin_options_type() == BuiltinOptions_SubOptions ? static_cast(builtin_options()) : nullptr; @@ -5068,12 +5296,24 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const { return builtin_options_type() == BuiltinOptions_ExpandDimsOptions ? static_cast(builtin_options()) : nullptr; } + const EqualOptions *builtin_options_as_EqualOptions() const { + return builtin_options_type() == BuiltinOptions_EqualOptions ? static_cast(builtin_options()) : nullptr; + } + const NotEqualOptions *builtin_options_as_NotEqualOptions() const { + return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast(builtin_options()) : nullptr; + } + const ShapeOptions *builtin_options_as_ShapeOptions() const { + return builtin_options_type() == BuiltinOptions_ShapeOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } CustomOptionsFormat custom_options_format() const { return static_cast(GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); } + const flatbuffers::Vector *mutating_variable_inputs() const { + return GetPointer *>(VT_MUTATING_VARIABLE_INPUTS); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_OPCODE_INDEX) && @@ -5087,6 +5327,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && verifier.Verify(custom_options()) && VerifyField(verifier, VT_CUSTOM_OPTIONS_FORMAT) && + VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) && + verifier.Verify(mutating_variable_inputs()) && verifier.EndTable(); } OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -5198,8 +5440,8 @@ template<> inline const TransposeOptions *Operator::builtin_options_as inline const MeanOptions *Operator::builtin_options_as() const { - return builtin_options_as_MeanOptions(); +template<> inline const ReducerOptions *Operator::builtin_options_as() const { + return builtin_options_as_ReducerOptions(); } template<> inline const SubOptions *Operator::builtin_options_as() const { @@ -5302,6 +5544,18 @@ template<> inline const ExpandDimsOptions *Operator::builtin_options_as inline const EqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_EqualOptions(); +} + +template<> inline const NotEqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_NotEqualOptions(); +} + +template<> inline const ShapeOptions *Operator::builtin_options_as() const { + return builtin_options_as_ShapeOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -5326,6 +5580,9 @@ struct OperatorBuilder { void add_custom_options_format(CustomOptionsFormat custom_options_format) { fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast(custom_options_format), 0); } + void add_mutating_variable_inputs(flatbuffers::Offset> mutating_variable_inputs) { + fbb_.AddOffset(Operator::VT_MUTATING_VARIABLE_INPUTS, mutating_variable_inputs); + } explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -5346,8 +5603,10 @@ inline flatbuffers::Offset CreateOperator( BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, flatbuffers::Offset> custom_options = 0, - CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS, + flatbuffers::Offset> mutating_variable_inputs = 0) { OperatorBuilder builder_(_fbb); + builder_.add_mutating_variable_inputs(mutating_variable_inputs); builder_.add_custom_options(custom_options); builder_.add_builtin_options(builtin_options); builder_.add_outputs(outputs); @@ -5366,7 +5625,8 @@ inline flatbuffers::Offset CreateOperatorDirect( BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, const std::vector *custom_options = nullptr, - CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS, + const std::vector *mutating_variable_inputs = nullptr) { return tflite::CreateOperator( _fbb, opcode_index, @@ -5375,7 +5635,8 @@ inline flatbuffers::Offset CreateOperatorDirect( builtin_options_type, builtin_options, custom_options ? _fbb.CreateVector(*custom_options) : 0, - custom_options_format); + custom_options_format, + mutating_variable_inputs ? _fbb.CreateVector(*mutating_variable_inputs) : 0); } flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -5746,6 +6007,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t { auto _e = buffer(); _o->buffer = _e; }; { auto _e = name(); if (_e) _o->name = _e->str(); }; { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr(_e->UnPack(_resolver)); }; + { auto _e = is_variable(); _o->is_variable = _e; }; } inline flatbuffers::Offset Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -5761,13 +6023,15 @@ inline flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder & auto _buffer = _o->buffer; auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0; + auto _is_variable = _o->is_variable; return tflite::CreateTensor( _fbb, _shape, _type, _buffer, _name, - _quantization); + _quantization, + _is_variable); } inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -6691,28 +6955,28 @@ inline flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferB _fbb); } -inline MeanOptionsT *MeanOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { - auto _o = new MeanOptionsT(); +inline ReducerOptionsT *ReducerOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ReducerOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void MeanOptions::UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void ReducerOptions::UnPackTo(ReducerOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; { auto _e = keep_dims(); _o->keep_dims = _e; }; } -inline flatbuffers::Offset MeanOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { - return CreateMeanOptions(_fbb, _o, _rehasher); +inline flatbuffers::Offset ReducerOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateReducerOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateReducerOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MeanOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReducerOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _keep_dims = _o->keep_dims; - return tflite::CreateMeanOptions( + return tflite::CreateReducerOptions( _fbb, _keep_dims); } @@ -7196,6 +7460,78 @@ inline flatbuffers::Offset CreateSparseToDenseOptions(flat _validate_indices); } +inline EqualOptionsT *EqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new EqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void EqualOptions::UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset EqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateEqualOptions( + _fbb); +} + +inline NotEqualOptionsT *NotEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new NotEqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void NotEqualOptions::UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset NotEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateNotEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NotEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNotEqualOptions( + _fbb); +} + +inline ShapeOptionsT *ShapeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ShapeOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ShapeOptions::UnPackTo(ShapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = out_type(); _o->out_type = _e; }; +} + +inline flatbuffers::Offset ShapeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateShapeOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ShapeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _out_type = _o->out_type; + return tflite::CreateShapeOptions( + _fbb, + _out_type); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -7244,6 +7580,7 @@ inline void Operator::UnPackTo(OperatorT *_o, const flatbuffers::resolver_functi { auto _e = builtin_options(); if (_e) _o->builtin_options.value = BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); }; { auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom_options[_i] = _e->Get(_i); } } }; { auto _e = custom_options_format(); _o->custom_options_format = _e; }; + { auto _e = mutating_variable_inputs(); if (_e) { _o->mutating_variable_inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->mutating_variable_inputs[_i] = _e->Get(_i) != 0; } } }; } inline flatbuffers::Offset Operator::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -7261,6 +7598,7 @@ inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuild auto _builtin_options = _o->builtin_options.Pack(_fbb); auto _custom_options = _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; auto _custom_options_format = _o->custom_options_format; + auto _mutating_variable_inputs = _o->mutating_variable_inputs.size() ? _fbb.CreateVector(_o->mutating_variable_inputs) : 0; return tflite::CreateOperator( _fbb, _opcode_index, @@ -7269,7 +7607,8 @@ inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuild _builtin_options_type, _builtin_options, _custom_options, - _custom_options_format); + _custom_options_format, + _mutating_variable_inputs); } inline SubGraphT *SubGraph::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -7486,8 +7825,8 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(obj); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case BuiltinOptions_SubOptions: { @@ -7590,6 +7929,18 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -7712,8 +8063,8 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(obj); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_SubOptions: { @@ -7816,6 +8167,18 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -7926,9 +8289,9 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateTransposeOptions(_fbb, ptr, _rehasher).Union(); } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(value); - return CreateMeanOptions(_fbb, ptr, _rehasher).Union(); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(value); + return CreateReducerOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_SubOptions: { auto ptr = reinterpret_cast(value); @@ -8030,6 +8393,18 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(value); + return CreateShapeOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -8140,8 +8515,8 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new TransposeOptionsT(*reinterpret_cast(u.value)); break; } - case BuiltinOptions_MeanOptions: { - value = new MeanOptionsT(*reinterpret_cast(u.value)); + case BuiltinOptions_ReducerOptions: { + value = new ReducerOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SubOptions: { @@ -8244,6 +8619,18 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new ExpandDimsOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_EqualOptions: { + value = new EqualOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_NotEqualOptions: { + value = new NotEqualOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ShapeOptions: { + value = new ShapeOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -8381,8 +8768,8 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } - case BuiltinOptions_MeanOptions: { - auto ptr = reinterpret_cast(value); + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(value); delete ptr; break; } @@ -8511,6 +8898,21 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc index 2f2004f56bcad5b56f9dd6d4bc824ec14d79e795..4eaf6f1bfe76efc1e6737d03d58be9bc87bb849d 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.cc +++ b/tensorflow/contrib/lite/simple_memory_arena.cc @@ -36,6 +36,12 @@ TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, ArenaAlloc* new_alloc) { TF_LITE_ENSURE(context, alignment < arena_alignment_); + if (size == 0) { + new_alloc->offset = 0; + new_alloc->size = 0; + return kTfLiteOk; + } + size_t current_top = 0; if (!allocs_.empty()) { @@ -75,6 +81,10 @@ TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, TfLiteStatus SimpleMemoryArena::Deallocate(TfLiteContext* context, const ArenaAlloc& alloc) { + if (alloc.size == 0) { + return kTfLiteOk; + } + int erased_allocs_count = 0; auto it = allocs_.begin(); while (it != allocs_.end()) { @@ -122,7 +132,11 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context, char** output_ptr) { TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); - *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + if (alloc.size == 0) { + *output_ptr = nullptr; + } else { + *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + } return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h index 5faf78b59e3755d22e4e866d433e622baa6c66c1..f738315cf2f91403f9dcb6fa9e66b40bd70495aa 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.h +++ b/tensorflow/contrib/lite/simple_memory_arena.h @@ -39,7 +39,8 @@ struct ArenaAlloc { // This small class is responsible for allocating, deallocating and reusing // dynamic memory from a common underlying buffer. The arena can be used in // scenarios when the pattern of memory allocations and deallocations is -// repetitive, e.g. running NN inference in multiple iterations. +// repetitive, e.g. running NN inference in multiple iterations. Note that +// zero-sized allocations are explicitly allowed, and will resolve to null. class SimpleMemoryArena { public: explicit SimpleMemoryArena(size_t arena_alignment) diff --git a/tensorflow/contrib/lite/simple_memory_arena_test.cc b/tensorflow/contrib/lite/simple_memory_arena_test.cc index 4444f642eb75c563c57762d095e454ac63d836c6..60d4d5e768aeda958574422e1c36a7cc2f6a1429 100644 --- a/tensorflow/contrib/lite/simple_memory_arena_test.cc +++ b/tensorflow/contrib/lite/simple_memory_arena_test.cc @@ -43,6 +43,47 @@ TEST(SimpleMemoryArenaTest, BasicArenaOperations) { EXPECT_EQ(allocs[5].offset, 1024); } +TEST(SimpleMemoryArenaTest, BasicZeroAlloc) { + TfLiteContext context; + SimpleMemoryArena arena(64); + ArenaAlloc alloc; + + // Zero-sized allocs should have a 0 offset and size. + ASSERT_EQ(arena.Allocate(&context, 32, 0, &alloc), kTfLiteOk); + EXPECT_EQ(alloc.offset, 0); + EXPECT_EQ(alloc.size, 0); + + // Deallocation of zero-sized allocs should always succeed (even redundantly). + ASSERT_EQ(arena.Deallocate(&context, alloc), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, alloc), kTfLiteOk); + + // The zero-sized alloc should resolve to null. + char* resolved_ptr = nullptr; + ASSERT_EQ(arena.Commit(&context), kTfLiteOk); + ASSERT_EQ(arena.ResolveAlloc(&context, alloc, &resolved_ptr), kTfLiteOk); + EXPECT_EQ(resolved_ptr, nullptr); +} + +TEST(SimpleMemoryArenaTest, InterleavedZeroAlloc) { + TfLiteContext context; + SimpleMemoryArena arena(64); + ArenaAlloc allocs[4]; + + // Interleave some zero and non-zero-sized allocations and deallocations. + ASSERT_EQ(arena.Allocate(&context, 32, 2047, &allocs[0]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 0, &allocs[1]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 1023, &allocs[2]), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, allocs[1]), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, allocs[2]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 2047, &allocs[3]), kTfLiteOk); + + // Deallocation of a zero-sized alloc should not impact the allocator offsets. + EXPECT_EQ(allocs[0].offset, 0); + EXPECT_EQ(allocs[1].offset, 0); + EXPECT_EQ(allocs[2].offset, 2048); + EXPECT_EQ(allocs[3].offset, 2048); +} + TEST(SimpleMemoryArenaTest, TestAfterClear) { TfLiteContext context; SimpleMemoryArena arena(64); diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc index a89776b29f895fe82ee71efe00c0949c58c109df..a316a40b62d89189da43768d448acdf5bbeca129 100644 --- a/tensorflow/contrib/lite/string_util.cc +++ b/tensorflow/contrib/lite/string_util.cc @@ -105,7 +105,7 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) { dims->data[0] = offset_.size() - 1; // Store number of strings. TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params, tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation, - tensor); + tensor->is_variable, tensor); } int GetStringCount(const char* raw_buffer) { diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 80e4c5a4dde4702229887593afc5ffeef339176d..b823c97f38e7660652aa0ce3538b11de59dc9aea 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -20,11 +20,15 @@ load( size = "large", srcs = ["generated_examples_zip_test.cc"], args = [ - "--zip_file_path=$(location :zip_%s)" % test_name, - # TODO(angerson) We may be able to add an external unzip binary instead - # of relying on an existing one for OSS builds. - "--unzip_binary_path=/usr/bin/unzip", - ], + ] + select({ + "//tensorflow:android": [], + "//conditions:default": [ + "--zip_file_path=$(location :zip_%s)" % test_name, + # TODO(angerson) We may be able to add an external unzip binary instead + # of relying on an existing one for OSS builds. + "--unzip_binary_path=/usr/bin/unzip", + ], + }), data = [ ":zip_%s" % test_name, ], diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 9bb7a4600dde9b59eef171cc535f1b95af4e553f..c4d2d7ca52ad9b3652682e3d5127d11246b14005 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -58,10 +58,11 @@ from tensorflow.python.ops import rnn parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") parser.add_argument("output_path", help="Directory where the outputs will be go.") -parser.add_argument("--zip_to_output", - type=str, - help="Particular zip to output.", - required=False) +parser.add_argument( + "--zip_to_output", + type=str, + help="Particular zip to output.", + required=True) parser.add_argument("--toco", type=str, help="Path to toco tool.", @@ -97,8 +98,6 @@ KNOWN_BUGS = { r"fully_connected.*transpose_.=True": "67586970", # Softmax graphs are too complex. r"softmax.*dim=0": "67749831", - # SpaceToDepth only supports float32. - r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", # BatchToSpaceND only supports 4D tensors. r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", # Div will use floordiv. @@ -138,7 +137,7 @@ def toco_options(data_types, Returns: the options in a string. """ - shape_str = ":".join([",".join(str(y) for y in x) for x in shapes]) + shape_str = ":".join([",".join(str(y) for y in x) for x in shapes if x]) inference_type = "FLOAT" # TODO(ahentz): if we get multi-input quantization to work we need this # to change @@ -835,6 +834,12 @@ def make_mean_tests(zip_path): return make_reduce_tests(tf.reduce_mean)(zip_path) +def make_sum_tests(zip_path): + """Make a set of tests to do sum.""" + + return make_reduce_tests(tf.reduce_sum)(zip_path) + + def make_exp_tests(zip_path): """Make a set of tests to do exp.""" @@ -1540,6 +1545,32 @@ def make_reshape_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_shape_tests(zip_path): + """Make a set of tests to do shape.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[], [0], [1, 1, 1, 3], [2, 3, 4, 5], [5, 5], [10]], + "out_type": [tf.int32, tf.int64], + }] + + def build_graph(parameters): + """Build the topk op testing graph.""" + # Note that we intentionally leave out the shape from the input placeholder + # to prevent the Shape operation from being optimized out during conversion. + input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input") + out = tf.shape(input_value, out_type=parameters["out_type"]) + return [input_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_resize_bilinear_tests(zip_path): """Make a set of tests to do resize_bilinear.""" @@ -1621,7 +1652,7 @@ def make_space_to_depth_tests(zip_path): """Make a set of tests to do space_to_depth.""" test_parameters = [{ - "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64], + "dtype": [tf.float32, tf.int32, tf.uint8, tf.int64], "input_shape": [[2, 12, 24, 1]], "block_size": [2, 3, 4], }] @@ -2166,6 +2197,74 @@ def make_arg_max_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_equal_tests(zip_path): + """Make a set of tests to do equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the equal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_not_equal_tests(zip_path): + """Make a set of tests to do not equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the not euqal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.not_equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_greater_tests(zip_path): """Make a set of tests to do greater.""" @@ -2353,30 +2452,54 @@ def make_neg_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def _make_elementwise_tests(op): + """Make a set of tests to do element-wise operations.""" + + def f(zip_path): + """Actual function that generates examples.""" + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + }] + + def build_graph(parameters): + """Build the unary op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape"]) + out = op(input_value) + return [input_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict={inputs[0]: input_value}) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return f + + def make_sin_tests(zip_path): """Make a set of tests to do sin.""" + return _make_elementwise_tests(tf.sin)(zip_path) - test_parameters = [{ - "input_dtype": [tf.float32], - "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], - }] - def build_graph(parameters): - """Build the sin op testing graph.""" - input_value = tf.placeholder( - dtype=parameters["input_dtype"], - name="input1", - shape=parameters["input_shape"]) - out = tf.sin(input_value) - return [input_value], [out] +def make_log_tests(zip_path): + """Make a set of tests to do log.""" + return _make_elementwise_tests(tf.log)(zip_path) - def build_inputs(parameters, sess, inputs, outputs): - input_value = create_tensor_data(parameters["input_dtype"], - parameters["input_shape"]) - return [input_value], sess.run( - outputs, feed_dict={inputs[0]: input_value}) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_sqrt_tests(zip_path): + """Make a set of tests to do sqrt.""" + return _make_elementwise_tests(tf.sqrt)(zip_path) + + +def make_rsqrt_tests(zip_path): + """Make a set of tests to do 1/sqrt.""" + return _make_elementwise_tests(tf.rsqrt)(zip_path) def make_where_tests(zip_path): diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index e85020448a572650c6a70d8b4dcb4e73faf0f8c8..8a59d756f8dbbcefc930b5285c1ced8ce6b08845 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -36,7 +36,12 @@ bool FLAGS_ignore_known_bugs = true; // TODO(b/71769302) zip_files_dir should have a more accurate default, if // possible string* FLAGS_zip_file_path = new string("./"); +#ifndef __ANDROID__ string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip"); +#else +string* FLAGS_unzip_binary_path = new string("/system/bin/unzip"); +#endif +bool FLAGS_use_nnapi = false; } // namespace // TensorFlow system environment for file system called. @@ -212,7 +217,7 @@ TEST_P(OpsTest, RunZipTests) { std::ifstream tflite_stream(tflite_test_case); ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case; - tflite::testing::TfLiteDriver test_driver(/*use_nnapi=*/true); + tflite::testing::TfLiteDriver test_driver(FLAGS_use_nnapi); test_driver.SetModelBaseDir(tflite_dir); string bug_number; @@ -273,7 +278,10 @@ int main(int argc, char** argv) { "Required: Location of the test zip file."), tensorflow::Flag("unzip_binary_path", tflite::testing::FLAGS_unzip_binary_path, - "Required: Location of a suitable unzip binary.")}; + "Required: Location of a suitable unzip binary."), + tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi, + "Whether to enable the NNAPI delegate")}; + bool success = tensorflow::Flags::Parse(&argc, argv, flags); if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) { fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); @@ -281,6 +289,8 @@ int main(int argc, char** argv) { } ::tflite::LogToStderr(); + // TODO(mikie): googletest arguments do not work - maybe the tensorflow flags + // parser removes them? ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index fc28faf52405b300dc6e4f0aab33122bb5e98f12..4d08fb545801521213890a4f5a9b010de57b27cd 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -163,6 +163,7 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) { Invalidate("Failed build interpreter"); return; } + interpreter_->UseNNAPI(use_nnapi_); must_allocate_tensors_ = true; } @@ -284,9 +285,11 @@ bool TfLiteDriver::CheckResults() { } void TfLiteDriver::ResetLSTMStateTensors() { - // This is a workaround for initializing state tensors for LSTM. - // TODO(ycling): Refactoring and find a better way to initialize state - // tensors. Maybe write the reset instructions into the test data. + interpreter_->ResetVariableTensorsToZero(); + + // Below is a workaround for initializing state tensors for LSTM. + // TODO(ycling): Remove the code below after nobody is using the 18-inputs + // definition. for (auto node_index : interpreter_->execution_plan()) { const auto& node_and_reg = interpreter_->node_and_registration(node_index); const auto& node = node_and_reg->first; @@ -296,19 +299,12 @@ void TfLiteDriver::ResetLSTMStateTensors() { const auto* params = reinterpret_cast(node.builtin_data); if (params->kernel_type == kTfLiteLSTMFullKernel && - node.outputs->size >= 2) { + node.inputs->size == 18 && node.outputs->size >= 2) { // The first 2 outputs of LSTM are state tensors. for (int i = 0; i < 2; ++i) { int node_index = node.outputs->data[i]; ResetTensor(node_index); } - } else if (params->kernel_type == kTfLiteLSTMBasicKernel && - node.inputs->size == 5) { - // The 2th and 5th inputs are state tensors. - for (int i : {1, 4}) { - int node_index = node.inputs->data[i]; - ResetTensor(node_index); - } } } } diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 7ea4f32ef694f3b0dc9c030b9440268ac79848aa..dd05c484fabf4d87dc12b39940a71677af4023e2 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -213,6 +213,7 @@ cc_library( "graph_transformations/convert_squeeze_to_reshape.cc", "graph_transformations/convert_trivial_addn_to_add.cc", "graph_transformations/convert_trivial_stack_to_reshape.cc", + "graph_transformations/convert_trivial_tile_to_concat.cc", "graph_transformations/convert_trivial_transpose_to_reshape.cc", "graph_transformations/create_im2col_arrays.cc", "graph_transformations/dequantize.cc", @@ -224,6 +225,7 @@ cc_library( "graph_transformations/fuse_activation_functions.cc", "graph_transformations/fuse_binary_into_following_affine.cc", "graph_transformations/fuse_binary_into_preceding_affine.cc", + "graph_transformations/fuse_broadcast_into_following_binary.cc", "graph_transformations/graph_transformations.cc", "graph_transformations/hardcode_min_max.cc", "graph_transformations/identify_dilated_conv.cc", @@ -293,7 +295,6 @@ cc_library( "graph_transformations/resolve_tensorflow_matmul.cc", "graph_transformations/resolve_tensorflow_merge.cc", "graph_transformations/resolve_tensorflow_switch.cc", - "graph_transformations/resolve_tensorflow_tile.cc", "graph_transformations/resolve_transpose_attributes.cc", "graph_transformations/unfuse_activation_functions.cc", "graph_transformations/unpartition_embedding_lookup.cc", @@ -374,6 +375,7 @@ tf_cc_test( ":toco_tooling", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_googletest//:gtest_main", ], @@ -411,6 +413,7 @@ tf_cc_test( deps = [ ":model", ":tooling_util", + "//tensorflow/core:lib", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 8913b5c3ea962725ef2bed73e670e8f0b988a591..6877fb237c0514a972589ac0301647104f5ed7ed 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -146,6 +146,7 @@ NodeProperties GetPropertiesForArray(const Model& model, NodeProperties node_properties; node_properties.color = GetColorForArray(model, array_name); node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}}); + node_properties.log2_buffer_size = 0.0f; // Append array shape to the label. auto& array = model.GetArray(array_name); @@ -165,9 +166,12 @@ NodeProperties GetPropertiesForArray(const Model& model, } node_properties.label += "]"; - int buffer_size = RequiredBufferSizeForShape(array.shape()); - node_properties.log2_buffer_size = - std::log2(static_cast(buffer_size)); + int buffer_size = 0; + if (IsValid(array.shape())) { + buffer_size = RequiredBufferSizeForShape(array.shape()); + node_properties.log2_buffer_size = + std::log2(static_cast(buffer_size)); + } if (array.buffer) { const auto& array = model.GetArray(array_name); @@ -200,8 +204,6 @@ NodeProperties GetPropertiesForArray(const Model& model, AppendF(&node_properties.label, "}"); } } - } else { - node_properties.log2_buffer_size = 0.0f; } if (array.minmax) { @@ -225,7 +227,7 @@ NodeProperties GetPropertiesForArray(const Model& model, NodeProperties GetPropertiesForOperator(const Operator& op) { NodeProperties node_properties; - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { node_properties.label = static_cast(op).tensorflow_op; } else { diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 99f0c81a1bd8b720bfb31e82a12a3098f7eee1dd..6b78f1c05ee777e0d456cd70d07b58ff51271fec 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -494,7 +494,7 @@ void ConvertTransposeConvOperator(const Model& model, const auto& weights_array = model.GetArray(weights_array_name); CHECK(weights_array.buffer->type == ArrayDataType::kFloat); ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI, - AxesOrder::kHWIO, tensorflow_graph); + AxesOrder::kHWOI, tensorflow_graph); auto& strides = (*conv2d_op->mutable_attr())["strides"]; strides.mutable_list()->add_i(1); strides.mutable_list()->add_i(src_op.stride_height); @@ -735,8 +735,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op != nullptr && - providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -776,8 +775,7 @@ void ConvertLogSoftmaxOperator(const Model& model, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op != nullptr && - providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -1047,6 +1045,18 @@ void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op, (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT); } +void ConvertRsqrtOperator(const Model& model, + const TensorFlowRsqrtOperator& src_op, + GraphDef* tensorflow_graph) { + auto* rsqrt_op = tensorflow_graph->add_node(); + rsqrt_op->set_op("Rsqrt"); + rsqrt_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 1); + *rsqrt_op->add_input() = src_op.inputs[0]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*rsqrt_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertSplitOperator(const Model& model, const TensorFlowSplitOperator& src_op, GraphDef* tensorflow_graph) { @@ -1687,6 +1697,22 @@ void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, (*sub_op->mutable_attr())["T"].set_type(data_type); } +void ConvertTileOperator(const Model& model, + const TensorFlowTileOperator& src_op, + GraphDef* tensorflow_graph) { + auto* tile_op = tensorflow_graph->add_node(); + tile_op->set_op("Tile"); + tile_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *tile_op->add_input() = src_op.inputs[0]; + *tile_op->add_input() = src_op.inputs[1]; + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*tile_op->mutable_attr())["T"].set_type(data_type); + const auto multiples_data_type = + GetTensorFlowDataType(model, src_op.inputs[1]); + (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type); +} + void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, GraphDef* tensorflow_graph) { auto* topk_op = tensorflow_graph->add_node(); @@ -1827,20 +1853,24 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertConcatenationOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowReshape) { + } else if (src_op.type == OperatorType::kReshape) { ConvertTensorFlowReshapeOperator( model, static_cast(src_op), tensorflow_graph); } else if (src_op.type == OperatorType::kL2Pool) { ConvertL2PoolOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSquare) { + } else if (src_op.type == OperatorType::kSquare) { ConvertSquareOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSqrt) { + } else if (src_op.type == OperatorType::kSqrt) { ConvertSqrtOperator(static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowSplit) { + } else if (src_op.type == OperatorType::kRsqrt) { + ConvertRsqrtOperator(model, + static_cast(src_op), + tensorflow_graph); + } else if (src_op.type == OperatorType::kSplit) { ConvertSplitOperator(model, static_cast(src_op), tensorflow_graph); @@ -1884,11 +1914,11 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kSub) { ConvertSubOperator(model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowMinimum) { + } else if (src_op.type == OperatorType::kMinimum) { ConvertTensorFlowMinimumOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowMaximum) { + } else if (src_op.type == OperatorType::kMaximum) { ConvertTensorFlowMaximumOperator( model, static_cast(src_op), tensorflow_graph); @@ -1907,7 +1937,7 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kTranspose) { ConvertTransposeOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowShape) { + } else if (src_op.type == OperatorType::kShape) { ConvertTensorFlowShapeOperator( model, static_cast(src_op), tensorflow_graph); @@ -1938,17 +1968,25 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertRandomUniformOperator( model, static_cast(src_op), tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowGreater) { + } else if (src_op.type == OperatorType::kEqual) { + ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph); + } else if (src_op.type == OperatorType::kNotEqual) { + ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph); + } else if (src_op.type == OperatorType::kGreater) { ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) { + } else if (src_op.type == OperatorType::kGreaterEqual) { ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowLess) { + } else if (src_op.type == OperatorType::kLess) { ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph); - } else if (src_op.type == OperatorType::kTensorFlowLessEqual) { + } else if (src_op.type == OperatorType::kLessEqual) { ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph); } else if (src_op.type == OperatorType::kSelect) { ConvertSelectOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTile) { + ConvertTileOperator(model, + static_cast(src_op), + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md index 7680cdd344814bf6cbc7bbe11c915f220642d55d..a4b4e14ba6a9ceb56fffacbe97a934242bca9407 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md @@ -26,8 +26,6 @@ Table of contents: * [Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef format](#to-graphdef) * [Logging](#logging) - * [Standard logging](#standard-logging) - * [Verbose logging](#verbose-logging) * [Graph "video" logging](#graph-video-logging) * [Graph visualizations](#graph-visualizations) * [Using --output_format=GRAPHVIZ_DOT](#using-output-formatgraphviz-dot) @@ -41,7 +39,7 @@ FlatBuffer to perform floating-point inference. ``` bazel run --config=opt \ - third_party/tensorflow/contrib/lite/toco:toco -- \ + //tensorflow/contrib/lite/toco:toco -- \ --savedmodel_directory=/tmp/saved_model \ --output_file=/tmp/foo.tflite ``` @@ -277,49 +275,6 @@ bazel run --config=opt \ ## Logging -### Standard logging - -The converter generates some informative log messages during processing. The -easiest way to view them is to add `--logtostderr` to command lines as seen in -the following example. - -``` -curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ - | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ - --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ - --logtostderr -``` - -After some initialization messages, we get the following informative messages: - -``` -I1101 21:51:33.297475 5339 graph_transformations.cc:39] Before general graph transformations: 416 operators, 583 arrays (0 quantized) -I1101 21:51:33.308972 5339 graph_transformations.cc:39] After general graph transformations pass 1: 31 operators, 89 arrays (0 quantized) -I1101 21:51:33.309204 5339 graph_transformations.cc:39] Before dequantization graph transformations: 31 operators, 89 arrays (0 quantized) -I1101 21:51:33.309368 5339 allocate_transient_arrays.cc:312] Total transient array allocated size: 1048576 bytes, theoretical optimal value: 786432 bytes. -I1101 21:51:33.309484 5339 toco_tooling.cc:249] Estimated count of arithmetic ops: 0.099218 billion (note that a multiply-add is counted as 2 ops). -``` - -### Verbose logging - -For debugging purposes, the converter supports two levels of verbose logging, -which can be set by passing a `--v=` flag: - -* For `--v=1`, the converter generates text dumps of the graph at various - points during processing as well as log messages about every graph - transformation that took place. -* For `--v=2`, the converter additionally generates log messages about graph - transformations that were considered but not performed. - ### Graph "video" logging When `--dump_graphviz=` is used (see the section on [graph diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index a8381169b83838b6a72b63bf17ca7e3fda6adf2d..8085ae07489816c38677ff792e7ac71f1a75fa71 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -209,16 +209,6 @@ have. ## Logging flags -The following are standard Google logging flags: - -* `--logtostderr` redirects Google logging to standard error, typically making - it visible in a terminal. -* `--v` sets verbose logging levels (for debugging purposes). Defined levels: - * `--v=1`: log all graph transformations that did make a change on the - graph. - * `--v=2`: log all graph transformations that did *not* make a change on - the graph. - The following flags allow to generate graph visualizations of the actual graph at various points during transformations: diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index 5071361bfdb4edd7711f6f777c8c384ee5802cb6..a7841a685528fb18bb08f1943278339a2daec16a 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -138,7 +138,8 @@ out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output") with tf.Session() as sess: converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8 - converter.quantized_input_stats = {"img" : (0., 1.)} # mean, std_dev + input_arrays = converter.get_input_arrays() + converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev tflite_model = converter.convert() open("converted_model.tflite", "wb").write(tflite_model) ``` diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index 0fffab574ddd8ad75ec07ae4442f363a36ed289e..1ea83abf8eb1b49f649e81def29857094cd0c2d7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -38,6 +38,16 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { // Depthwise conv does not support dilation return false; } + auto& input_array = model->GetArray(conv_op->inputs[0]); + if (!input_array.has_shape()) { + // Shapes not propagated yet + return false; + } + if (input_array.shape().dims(3) != 1) { + // Not a pure convolution: Conv does accumulation across the depth + // dimension. + return false; + } auto& weights_array = model->GetArray(conv_op->inputs[1]); if (!weights_array.buffer) { // Yield until the weights are resolved as a constant array. @@ -46,11 +56,6 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) { if (weights_array.data_type != ArrayDataType::kFloat) { return false; } - if (weights_array.shape().dims(3) != 1) { - // Not a pure convolution: Conv does accumulation across the depth - // dimension. - return false; - } // At this point we know we have a pure conv. Rewrite it as DepthwiseConv. AddMessageF( "%s is purely convolutional (input/weights depth is 1), replacing it by " diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc new file mode 100644 index 0000000000000000000000000000000000000000..b689be07926ecd9be4cc317735dc88eb90950e13 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) { + auto tile_it = model->operators.begin() + op_index; + if (tile_it->get()->type != OperatorType::kTile) { + return false; + } + auto* tile_op = static_cast(tile_it->get()); + + const auto& input_array = model->GetArray(tile_op->inputs[0]); + const auto& multiples_array = model->GetArray(tile_op->inputs[1]); + const auto& output_array = model->GetArray(tile_op->outputs[0]); + if (!input_array.has_shape() || !multiples_array.has_shape() || + !output_array.has_shape()) { + // Yield until PropagateFixedSizes has been run on this op. + return false; + } + // Note: We can assume we have error checked inputs in PropagateFixedSizes. + + if (!multiples_array.buffer) { + // Yield until the multiples is constant. + return false; + } + std::vector const& multiples = + multiples_array.GetBuffer().data; + + // We can simplify the tile if only a single dimension is being multiplied. + // It then just becomes a concat along that dimension. + int non_one_dims = 0; + int concat_axis = 0; + for (int i = 0; i < multiples.size(); ++i) { + if (multiples[i] != 1) { + ++non_one_dims; + concat_axis = i; + } + } + if (non_one_dims != 1) { + // The tile is non-trivial. Good luck. + AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)", + LogName(*tile_op)); + return false; + } + + // The tile is like a concat. + AddMessageF("Simplifying %s to a Concat along a single axis %d", + LogName(*tile_op), concat_axis); + + auto* concat_op = new ConcatenationOperator; + + // Copy input and output. + // Note that we multiply out the input by the number of times requested. + for (int i = 0; i < multiples[concat_axis]; ++i) { + concat_op->inputs.push_back(tile_op->inputs[0]); + } + concat_op->axis = concat_axis; + concat_op->outputs = tile_op->outputs; + + // Delete multiples array if unused. + if (IsDiscardableArray(*model, tile_op->inputs[1]) && + CountOpsWithInput(*model, tile_op->inputs[1]) == 1) { + model->EraseArray(tile_op->inputs[1]); + } + + // Replace the operator in the graph. + const auto concat_it = model->operators.emplace(tile_it, concat_op); + tile_it = concat_it + 1; + CHECK_EQ(tile_it->get(), tile_op); + model->operators.erase(tile_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc index 076415ece8c1039caa32e947fe54ab3e101bec9e..1e68cd678bce6c27f1852a5ae0c13362d8938cdd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -25,17 +25,12 @@ limitations under the License. namespace toco { -bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { - auto conv_it = model->operators.begin() + op_index; - if (conv_it->get()->type != OperatorType::kConv) { - return false; - } - auto* conv_op = static_cast(conv_it->get()); - if (conv_op->outputs.size() == 2) { +bool ProcessConvOperator(Model* model, ConvOperator* op) { + if (op->outputs.size() == 2) { // We already have an im2col array return false; } - const auto& weights_array = model->GetArray(conv_op->inputs[1]); + const auto& weights_array = model->GetArray(op->inputs[1]); if (!weights_array.has_shape()) { // We need to yield until weights dims have been resolved, because // from the weights dims we determine whether an im2col array is @@ -45,25 +40,52 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { const auto& weights_shape = weights_array.shape(); const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); - if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 && - conv_op->stride_height == 1) { - // 1x1 unstrided conv does not need an im2col array. + if (kwidth == 1 && kheight == 1 && op->stride_width == 1 && + op->stride_height == 1 && op->dilation_width_factor == 1 && + op->dilation_height_factor == 1) { + // 1x1 unstrided undilated conv does not need an im2col array. return false; } // Create the im2col array. - CHECK_EQ(conv_op->outputs.size(), 1); + CHECK_EQ(op->outputs.size(), 1); const string& im2col_array_name = - AvailableArrayName(*model, conv_op->inputs[0] + "_im2col"); + AvailableArrayName(*model, op->inputs[0] + "_im2col"); model->GetOrCreateArray(im2col_array_name); - conv_op->outputs.push_back(im2col_array_name); - AddMessageF( - "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, " - "stride_height=%d", - LogName(*conv_op), kwidth, kheight, conv_op->stride_width, - conv_op->stride_height); + op->outputs.push_back(im2col_array_name); return true; } +bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { + if (op->outputs.size() == 2) { + // We already have an im2col array + return false; + } + + // Always create an im2col array for transpose_conv. + CHECK_EQ(op->outputs.size(), 1); + const string& im2col_array_name = AvailableArrayName( + *model, op->inputs[TransposeConvOperator::DATA_INPUT] + "_im2col"); + model->GetOrCreateArray(im2col_array_name); + op->outputs.push_back(im2col_array_name); + + return true; +} + +bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + + switch (op->type) { + case OperatorType::kConv: + return ProcessConvOperator(model, static_cast(op)); + case OperatorType::kTransposeConv: + return ProcessTransposeConvOperator( + model, static_cast(op)); + default: + return false; + } +} + } // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc index 498c864bde6d656c8318e981204cb42cb3a4d03f..2c7ffe488477ef1a544dfe6f36a6e0d1ac40aa96 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc @@ -111,7 +111,7 @@ bool DequantizeArray(const string& array_name, auto* op_outputting_array = GetOpWithOutput(*model, array_name); if (op_outputting_array) { - if (op_outputting_array->type == OperatorType::kTensorFlowReshape) { + if (op_outputting_array->type == OperatorType::kReshape) { return true; } } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc new file mode 100644 index 0000000000000000000000000000000000000000..874d8def571fbce4219de15285c8df6fd2487a9a --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +// Returns true if the given op is strictly a broadcasting operation. +// This is commonly seen as a Concat of the same input multiple times, and is +// often generated from Tile ops that were converted via the +// convert_trivial_tile_to_concat transformation. +bool IsBroadcastingOp(const Model& model, Operator* op) { + // Concatenation of identical inputs is usually a broadcast. + if (op->type == OperatorType::kConcatenation) { + // Verify that all inputs are the same. + for (int i = 1; i < op->inputs.size(); ++i) { + if (op->inputs[i] != op->inputs[0]) { + return false; + } + } + return true; + } + + // There are other things we could look for (Stack/etc) when needed. + return false; +} + +} // namespace + +// Finds an operation that looks like a broadcast (concat of the same sources +// along the last dimension) and drops it by relying on the ability of certain +// binary ops to perform an implicit broadcast. +bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) { + const auto binary_it = model->operators.begin() + op_index; + auto* binary_op = binary_it->get(); + + // Test for binary ops of types that we know how to resolve + if (binary_op->inputs.size() != 2) { + return false; + } + if (binary_op->type != OperatorType::kAdd && + binary_op->type != OperatorType::kMul && + binary_op->type != OperatorType::kSub && + binary_op->type != OperatorType::kDiv) { + return false; + } + + // NOTE: either of these ops may be nullptr if the input array is constant. + Operator* const op[2] = { + GetOpWithOutput(*model, binary_op->inputs[0]), + GetOpWithOutput(*model, binary_op->inputs[1]), + }; + + // Check whether either input is a broadcast-like concat. + bool is_op_0_broadcast = op[0] && IsBroadcastingOp(*model, op[0]); + bool is_op_1_broadcast = op[1] && IsBroadcastingOp(*model, op[1]); + if (!is_op_0_broadcast && !is_op_1_broadcast) { + // Neither input is a broadcast-looking thing. + AddMessageF("Neither input looks broadcasty"); + return false; + } else if (is_op_0_broadcast && is_op_1_broadcast) { + AddMessageF( + "Unable to fuse broadcast into %s as both inputs (%s, %s) are " + "broadcasts", + LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)", + op[1] ? LogName(*op[1]) : "(?)"); + return false; + } + int broadcast_index = is_op_0_broadcast ? 0 : 1; + + // Just pull out the input of the broadcast op and pass it directly to the + // binary op. + AddMessageF("Fusing broadcast op %s into the following binary %s", + LogName(*op[broadcast_index]), LogName(*binary_op)); + binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0]; + + // We leave the broadcast op in; it'll get cleaned up if it's not used later. + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 1bc7557d46cfa5e1b27468d2da271e75fd491d36..62a09acdfbb553161e480051aa506486b9adec47 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -117,12 +117,14 @@ DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise) DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape) +DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTileToConcat) DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape) DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes) DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors) DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine) DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine) +DECLARE_GRAPH_TRANSFORMATION(FuseBroadcastIntoFollowingBinary) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization) DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool) DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) @@ -165,7 +167,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge) DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch) -DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index d63ee7c9519d169a2f44ec1afe81125217db8976..82a4308ecb134d28c37f4519ae783b50bf35477a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -353,7 +353,7 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForConcatenation(model, op); break; - case OperatorType::kTensorFlowSplit: + case OperatorType::kSplit: changed = HardcodeMinMaxForSplit(model, op); break; @@ -362,9 +362,11 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForAverageOrMaxPool(model, op); break; + case OperatorType::kResizeBilinear: + case OperatorType::kSlice: case OperatorType::kStridedSlice: case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kPad: case OperatorType::kGather: case OperatorType::kTranspose: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc index ae3301f467de5714230e731b4bab87ddc1637201..d49857cfc22ecaf5feb06b39a42187f8adb61d50 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -90,12 +90,13 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { } // Conv Op - ConvOperator* conv_op = dynamic_cast( - has_expand_op ? GetOpWithInput(*model, post_stb_op->outputs[0]) - : GetOpWithInput(*model, stb_op->outputs[0])); - if (!conv_op || conv_op->type != OperatorType::kConv) { + const string& input_of_conv_op = + has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0]; + auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op); + if (conv_base_op->type != OperatorType::kConv) { return false; } + auto* conv_op = static_cast(conv_base_op); if (conv_op->inputs.size() != 2) { // The conv op must only have weights, no bias. return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc index 419a0776a6b987a18df059d3c1d4bf4370cd24d8..b78efd7fc3602dc2d6e03fd28d694c344b61c17c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -44,10 +44,9 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { const auto* div_or_mul_op = div_it->get(); OperatorType expected_op_type_producing_div_or_mul_input; if (div_or_mul_op->type == OperatorType::kDiv) { - expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt; + expected_op_type_producing_div_or_mul_input = OperatorType::kSqrt; } else if (div_or_mul_op->type == OperatorType::kMul) { - expected_op_type_producing_div_or_mul_input = - OperatorType::kTensorFlowRsqrt; + expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt; } else { return false; } @@ -75,8 +74,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { Operator* add_op = nullptr; Operator* op_producing_add_input = nullptr; if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd || - op_producing_sqrt_or_rsqrt_input->type == - OperatorType::kTensorFlowMaximum) { + op_producing_sqrt_or_rsqrt_input->type == OperatorType::kMaximum) { add_op = op_producing_sqrt_or_rsqrt_input; bool add_can_be_removed = false; CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2); @@ -113,7 +111,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { Operator* sum_op = add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input; - if (sum_op->type != OperatorType::kTensorFlowSum) { + if (sum_op->type != OperatorType::kSum) { AddMessageF( "Giving up trying to identify L2Normalization subgraph: " "expected Sum op, got %s", @@ -122,7 +120,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { } Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); - if (square_op->type != OperatorType::kTensorFlowSquare) { + if (square_op->type != OperatorType::kSquare) { AddMessageF( "Giving up trying to identify L2Normalization subgraph: " "expected Square op, got %s", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc index e4d52476c649de53b3ab663f53ce7a5538dbb5ab..705e73779b7f74698149d5e9e56f69a371326ceb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc @@ -41,7 +41,7 @@ std::vector>::iterator FindOperator( bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { const auto sqrt_it = model->operators.begin() + op_index; const auto* sqrt_op = sqrt_it->get(); - if (sqrt_op->type != OperatorType::kTensorFlowSqrt) { + if (sqrt_op->type != OperatorType::kSqrt) { return false; } @@ -52,6 +52,13 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { const Operator* square_op; Operator* prev_to_sqrt_op = GetOpWithOutput(*model, sqrt_op->inputs[0]); + if (prev_to_sqrt_op == nullptr) { + AddMessageF( + "Giving up trying to identify L2Pool subgraph: " + "expected AveragePool op, but Sqrt op has no preceding op"); + return false; + } + if (prev_to_sqrt_op->type != OperatorType::kAveragePool) { AddMessageF( "Giving up trying to identify L2Pool subgraph: " @@ -65,7 +72,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) { square_op = GetOpWithOutput(*model, avpool_op->inputs[0]); CHECK_EQ(square_op->inputs.size(), 1); - if (square_op->type != OperatorType::kTensorFlowSquare) { + if (square_op->type != OperatorType::kSquare) { AddMessageF( "Giving up trying to identify L2Pool subgraph: " "expected Square op, got %s", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc index e9842524c829b839b97b3453a36c41efe186efbb..910e38a6ba6e7676d03648b4d0548edaf47b7d8a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc @@ -266,26 +266,26 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) { // State remember "information" activation function Operator* fc_output_split; - if (!MatchOperatorInputs(*state_info_tanh, *model, - OperatorType::kTensorFlowSplit, &fc_output_split)) { + if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit, + &fc_output_split)) { return false; } // State remember gate activation function Operator* tmp; - if (!MatchOperatorInputs(*state_remember_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } // State forget gate activation function - if (!MatchOperatorInputs(*state_forget_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } // Fully connected output activation function - if (!MatchOperatorInputs(*fc_output_sig, *model, - OperatorType::kTensorFlowSplit, &tmp) || + if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit, + &tmp) || (tmp != fc_output_split)) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index e6e3dfa1de9c9fdd5e759fd547d11a7b8c95d837..46d1fce50e5d6e2a74cf5461d731e46469dde5bf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -74,6 +74,12 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { lstm_cell_op->inputs[kInputTensor] = curr_op->inputs[LstmCellOperator::ACTIV_OUTPUT]; + // Previous states. + lstm_cell_op->inputs[kInputActivationStateTensor] = + curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]; + lstm_cell_op->inputs[kInputCellStateTensor] = + curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT]; + // Get original weight tensor and decompose 1 tensor to 8 sub tensors. Array& kernel = model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]); @@ -160,10 +166,6 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { // Erase curr lstm op being replaced. DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model); DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model); - DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT], - model); - DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT], - model); model->operators.erase(FindOp(*model, curr_op)); return true; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc index bddb563206f763a756685d196836fa41825cf045..94820a016622a12654e91967737e05fc91ed404c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -60,24 +60,22 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { // Follow sequences of min+max and max+min. First get the leading op. const auto op_it = model->operators.begin() + op_index; const auto* op_0 = op_it->get(); - if (op_0->type != OperatorType::kTensorFlowMinimum && - op_0->type != OperatorType::kTensorFlowMaximum) { + if (op_0->type != OperatorType::kMinimum && + op_0->type != OperatorType::kMaximum) { return false; } // Get the paired op and ensure it's the counter to the first. const auto* op_1 = GetOpWithInput(*model, op_0->outputs[0]); if (!op_1 || - (op_1->type != OperatorType::kTensorFlowMinimum && - op_1->type != OperatorType::kTensorFlowMaximum) || + (op_1->type != OperatorType::kMinimum && + op_1->type != OperatorType::kMaximum) || op_0->type == op_1->type) { return false; } - const auto* min_op = - op_0->type == OperatorType::kTensorFlowMinimum ? op_0 : op_1; - const auto* max_op = - op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1; + const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1; + const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1; if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) { return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h index 1c32a781698ec78003ebbf9caff28557924323e5..6d8603a1133a7478647b8bcc49ea1eceba28df31 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h @@ -47,10 +47,14 @@ enum ExtendedLstmCellInputs { kOutputGateBiasTensor = 15, kProjectionWeightsTensor = 16, // Optional kProjectionBiasTensor = 17, // Optional - kExtendedLstmInputCount = 18 + kInputActivationStateTensor = 18, + // The op can handle 18 inputs or 20 inputs. + kInputCellStateTensor = 19, + kExtendedLstmInputCount = 20, }; enum ExtendedLstmCellOutputs { + // TODO(ycling): Make the 2 output state tensors optional. kOutputStateTensor = 0, kCellStateTensor = 1, kOutputTensor = 2, diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc index 5065004093434475172a39efdcfd26c10c49148b..95bc7f7d4b8b517c1cc5a73b3e85bbd985ce460f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc @@ -106,7 +106,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model, std::size_t op_index) { auto it = model->operators.begin() + op_index; auto* reshape_op = ConvertOperator( - it->get(), OperatorType::kTensorFlowReshape); + it->get(), OperatorType::kReshape); if (reshape_op == nullptr) { return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 64096fb069d6393f1eee9ada82c20c33a3405de9..27a1049eaf830e2c690dbc68f80d37107eb76772 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -56,20 +56,22 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { // These operators unconditionally produce float outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat); break; - case OperatorType::kTensorFlowLess: - case OperatorType::kTensorFlowLessEqual: - case OperatorType::kTensorFlowGreater: - case OperatorType::kTensorFlowGreaterEqual: + case OperatorType::kLess: + case OperatorType::kLessEqual: + case OperatorType::kGreater: + case OperatorType::kGreaterEqual: + case OperatorType::kEqual: + case OperatorType::kNotEqual: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; case OperatorType::kRank: - case OperatorType::kTensorFlowShape: + case OperatorType::kShape: // These operators only produce int32 outputs. SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); break; - case OperatorType::kTensorFlowSplit: - case OperatorType::kTensorFlowConcat: + case OperatorType::kSplit: + case OperatorType::kConcat: case OperatorType::kFill: { // These operators produce an output with the same type as their 2nd input CHECK_GE(op->inputs.size(), 2); @@ -133,7 +135,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { model->GetArray(op->outputs[1]).data_type = ArrayDataType ::kInt32; break; } - case OperatorType::kTensorFlowUnsupported: { + case OperatorType::kUnsupported: { auto* unsupported_op = static_cast(op); // Some output tensors from the op could be eliminated by optimization. // This can make unsupported_op->output_data_types have more elements than diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 6d51fc8c31e6c86701c3dc1fd07a9a5479114738..e25125b429a7e33fa83c603eb85b931ab45ecb50 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -90,8 +90,8 @@ void ChangeArrayDataType(GraphTransformation* transformation, Array* array, bool DoesOpBlockBackwardPropagation(const Operator& op) { switch (op.type) { case OperatorType::kConcatenation: - case OperatorType::kTensorFlowConcat: - case OperatorType::kTensorFlowConcatV2: + case OperatorType::kConcat: + case OperatorType::kConcatV2: // Concat shouldn't block propagation, but we do expect that all inputs // have the same range. return false; @@ -100,9 +100,10 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) { // FakeQuant so make sure we move across them. case OperatorType::kGather: // Gathers need their parameters changed to the appropriate data type. - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: case OperatorType::kSelect: + case OperatorType::kTile: // Reshapes and transposes don't change values. return false; default: @@ -120,10 +121,13 @@ bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) { // Ignore gather indices. return input_index != 0; break; - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: // Ignore reshape/transpose shapes/dimensions. return input_index != 0; + case OperatorType::kTile: + // Ignore tile multiples. + return input_index != 0; default: return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index adb241da3246d985cdbcf5053a5926620d8487ff..c61da203c63ae0b449c7d4b0cb63945bb3551f3a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -211,12 +211,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { // might as well calculate the output shape and ensure it matches the // specified one - // Check if we have already run. - auto& output_array = model->GetArray(op->outputs[0]); - if (output_array.has_shape()) { - return; - } - // SPECIFIED OUTPUT SHAPE // The below is the specified, or prescribed output shape, _given_ to the // operator as an input. @@ -278,13 +272,23 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { << "TransposeConv input shape must have 4 dimensions. Input \"" << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " << toco::ShapeToString(weights_shape) << "."; - CHECK_EQ(input_shape.dims(3), weights_shape.dims(0)) + CHECK_EQ(input_shape.dims(3), weights_shape.dims(3)) << "Input shape depth and weight depth do not agree"; // Set the output shape according to the specified output shape. std::vector const& specified_output_shape = specified_output_shape_array.GetBuffer().data; + auto& output_array = model->GetArray(op->outputs[0]); *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape; + + // Set im2col array dimensions if there is one. + if (op->outputs.size() == 2) { + const int input_depth = weights_shape.dims(3); + auto& im2col_array = model->GetArray(op->outputs[1]); + im2col_array.copy_shape( + Shape{specified_output_shape[0], specified_output_shape[1], + specified_output_shape[2], input_depth * kheight * kwidth}); + } } void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { @@ -321,7 +325,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { if (!op->depth_multiplier) { op->depth_multiplier = output_depth / input_depth; } - QCHECK_EQ(output_depth, input_depth * op->depth_multiplier) + CHECK_EQ(output_depth, input_depth * op->depth_multiplier) << "input/output depths and depth_multiplier don't match"; const int kheight = weights_shape.dims(1); @@ -568,11 +572,11 @@ void ProcessAddNOperator(Model* model, Operator* op) { bool KeepDims(const Operator& op) { switch (op.type) { - case OperatorType::kTensorFlowMin: + case OperatorType::kMin: // Reduction Min return static_cast(op).keep_dims; - case OperatorType::kTensorFlowMax: + case OperatorType::kMax: // Reduction Max return static_cast(op).keep_dims; - case OperatorType::kTensorFlowSum: + case OperatorType::kSum: return static_cast(op).keep_dims; case OperatorType::kMean: return static_cast(op).keep_dims; @@ -1505,6 +1509,48 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) { } } +void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { + CHECK_EQ(op->inputs.size(), 2); + CHECK_EQ(op->outputs.size(), 1); + + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // We have already run. + return; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until input dims have been resolved. + return; + } + const auto& input_shape = input_array.shape(); + + auto& multiples_array = model->GetArray(op->inputs[1]); + if (!multiples_array.has_shape()) { + // Yield until multiples shape been resolved. + return; + } + if (!multiples_array.buffer) { + // Yield until the multiples is constant. + return; + } + CHECK(multiples_array.data_type == ArrayDataType::kInt32) + << "Tile multiples input must be int32"; + + std::vector const& multiples = + multiples_array.GetBuffer().data; + CHECK_EQ(multiples.size(), input_shape.dimensions_count()) + << "Tile multiples input " << op->inputs[1] + << " must be same length as input dimensions"; + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->resize(multiples.size()); + for (int i = 0; i < mutable_dims->size(); ++i) { + (*mutable_dims)[i] = input_shape.dims(i) * multiples[i]; + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1531,14 +1577,14 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kLogistic: case OperatorType::kTanh: case OperatorType::kLocalResponseNormalization: - case OperatorType::kTensorFlowIdentity: + case OperatorType::kIdentity: case OperatorType::kFakeQuant: case OperatorType::kNeg: - case OperatorType::kTensorFlowRsqrt: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: - case OperatorType::kTensorFlowAll: - case OperatorType::kTensorFlowAssert: + case OperatorType::kRsqrt: + case OperatorType::kSqrt: + case OperatorType::kSquare: + case OperatorType::kAll: + case OperatorType::kAssert: case OperatorType::kCast: case OperatorType::kFloor: case OperatorType::kExp: @@ -1557,12 +1603,14 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kDiv: case OperatorType::kFloorDiv: case OperatorType::kFloorMod: - case OperatorType::kTensorFlowLess: - case OperatorType::kTensorFlowLessEqual: - case OperatorType::kTensorFlowGreater: - case OperatorType::kTensorFlowMaximum: - case OperatorType::kTensorFlowMinimum: - case OperatorType::kTensorFlowGreaterEqual: + case OperatorType::kLess: + case OperatorType::kLessEqual: + case OperatorType::kGreater: + case OperatorType::kMaximum: // Element-wise Maximum + case OperatorType::kMinimum: // Element-wise Minimum + case OperatorType::kGreaterEqual: + case OperatorType::kEqual: + case OperatorType::kNotEqual: ProcessSimpleBinaryOperator(model, op); break; case OperatorType::kAddN: @@ -1595,7 +1643,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessFullyConnectedOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: ProcessTensorFlowReshapeOperator( model, static_cast(op)); break; @@ -1608,9 +1656,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kL2Pool: ProcessL2PoolOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowMin: - case OperatorType::kTensorFlowMax: - case OperatorType::kTensorFlowSum: + case OperatorType::kMin: // Reduction Min + case OperatorType::kMax: // Reduction Max + case OperatorType::kSum: case OperatorType::kMean: ProcessTensorFlowReductionOperator(model, op); break; @@ -1621,34 +1669,26 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSliceOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowTile: - // We don't currently implement the propagation of fixed sizes through - // a TensorFlow Tile. - // - // Fortunately, we don't need to: so far, we have only dealt with Tile - // or Slice ops in subgraphs that are identified as L2Normalization. - // See IdentifyL2Normalization. - break; - case OperatorType::kTensorFlowSwitch: + case OperatorType::kSwitch: // We can't know the sizes of the outputs until we have resolved the // predicate, and once we have resolved the predicate, the whole // Switch node will get resolved away. // See ResolveTensorFlowSwitch. break; - case OperatorType::kTensorFlowMerge: + case OperatorType::kMerge: // No need to bother resolving TensorFlow Merge ops: other graph // transformations will remove them anyway. // See ResolveTensorFlowMerge. break; - case OperatorType::kTensorFlowSplit: + case OperatorType::kSplit: ProcessTensorFlowSplitOperator(model, static_cast(op)); break; case OperatorType::kSqueeze: ProcessSqueezeOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowConcat: - case OperatorType::kTensorFlowConcatV2: + case OperatorType::kConcat: + case OperatorType::kConcatV2: // Unimplemented, hopefully another graph transformation will // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat // will resolve this node to a DepthConcatenation, or else we have @@ -1664,7 +1704,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kRank: ProcessRankOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowShape: + case OperatorType::kShape: ProcessShapeOperator(model, static_cast(op)); break; case OperatorType::kStack: @@ -1685,7 +1725,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessLstmCellOperator(model, static_cast(op)); break; case OperatorType::kBatchMatMul: - case OperatorType::kTensorFlowMatMul: + case OperatorType::kMatMul: // MatMul operators are converted to FullyConnected, after which their // shapes are propagated. break; @@ -1710,7 +1750,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kArgMax: ProcessArgMaxOperator(model, static_cast(op)); break; - case OperatorType::kTensorFlowUnsupported: + case OperatorType::kUnsupported: break; case OperatorType::kSvdf: ProcessSvdfOperator(model, static_cast(op)); @@ -1732,6 +1772,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { ProcessSparseToDenseOperator(model, static_cast(op)); break; + case OperatorType::kTile: + ProcessTileOperator(model, static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 142841fcc460e8a5e9e4f2333496f4ece2557275..1c61b8cb36ef9968c55f64c023fca8361162beb1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -33,7 +33,7 @@ namespace { bool SupportsQuantization(const Operator& op) { auto type = op.type; - if (type == OperatorType::kTensorFlowUnsupported) { + if (type == OperatorType::kUnsupported) { auto* unsupported = static_cast(&op); return unsupported->quantized; } @@ -42,25 +42,24 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kConcatenation || type == OperatorType::kL2Normalization || type == OperatorType::kAdd || type == OperatorType::kAveragePool || type == OperatorType::kMaxPool || - type == OperatorType::kTensorFlowMinimum || - type == OperatorType::kTensorFlowMaximum || + type == OperatorType::kMinimum || type == OperatorType::kMaximum || type == OperatorType::kLogistic || type == OperatorType::kSoftmax || - type == OperatorType::kLogSoftmax || - type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub || + type == OperatorType::kLogSoftmax || type == OperatorType::kSlice || + type == OperatorType::kResizeBilinear || + type == OperatorType::kSplit || type == OperatorType::kSub || type == OperatorType::kSqueeze || type == OperatorType::kPad || - type == OperatorType::kPadV2 || - type == OperatorType::kTensorFlowReshape || + type == OperatorType::kPadV2 || type == OperatorType::kReshape || type == OperatorType::kTanh || type == OperatorType::kMul || + type == OperatorType::kSpaceToBatchND || type == OperatorType::kSpaceToDepth || type == OperatorType::kStridedSlice || type == OperatorType::kDepthToSpace || type == OperatorType::kLstmCell || type == OperatorType::kGather || type == OperatorType::kTranspose || type == OperatorType::kMean || - type == OperatorType::kTensorFlowGreater || - type == OperatorType::kTensorFlowGreaterEqual || - type == OperatorType::kTensorFlowLess || - type == OperatorType::kTensorFlowLessEqual || - type == OperatorType::kSelect; + type == OperatorType::kGreater || + type == OperatorType::kGreaterEqual || type == OperatorType::kLess || + type == OperatorType::kLessEqual || type == OperatorType::kSelect || + type == OperatorType::kArgMax; } const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { @@ -328,12 +327,12 @@ bool ChooseQuantizationForOperatorOutput( } if ((op.type == OperatorType::kDepthToSpace) || (op.type == OperatorType::kSpaceToDepth) || - (op.type == OperatorType::kTensorFlowReshape) || - (op.type == OperatorType::kTensorFlowSplit) || + (op.type == OperatorType::kReshape) || + (op.type == OperatorType::kSplit) || (op.type == OperatorType::kConcatenation && model->flags.change_concat_input_ranges())) { int data_input_index = 0; - if (op.type == OperatorType::kTensorFlowSplit) { + if (op.type == OperatorType::kSplit) { data_input_index = 1; } // Copying and rearrangement ops should preserve the quantization parameters diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc index 35a0c465327f352863350e7a8af714d16b7be393..73ad326299bbd929afbb8dda2c41b97a126afbe1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -26,7 +26,7 @@ namespace toco { bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) { const auto assert_it = model->operators.begin() + op_index; const auto* assert_op = assert_it->get(); - if (assert_op->type != OperatorType::kTensorFlowAssert) { + if (assert_op->type != OperatorType::kAssert) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc index 404269bbfd9312bbbab32489783d9e4217ecbd89..7ec7752f25dad1c24b821733c0e6dafbd1cd8bf2 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc @@ -28,7 +28,7 @@ namespace toco { bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) { const auto passthru_it = model->operators.begin() + op_index; const auto* passthru_op = passthru_it->get(); - if (passthru_op->type != OperatorType::kTensorFlowIdentity) { + if (passthru_op->type != OperatorType::kIdentity) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc index a950fe6442bc656b725a1f0687f4c024f4fb0f84..9f5d8b94507ec11957c3ae55ffca510eeb81ac89 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc @@ -97,7 +97,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation, "Cannot remove %s, neither its main input nor its output may be " "discarded", LogName(*passthru_op)); - if (passthru_op->type != OperatorType::kTensorFlowReshape && + if (passthru_op->type != OperatorType::kReshape && model->GetArray(main_input_name).has_shape()) { // We can't remove either array but we can remove the op. Converting it to // a reshape gives us some hope of later on fixing that (either in the diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc index eaee1c662b7cedb2baec7be47e12e348c3e7b25c..142c876b154755ac9c6b93e560f22ec8d6ec6563 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc @@ -47,11 +47,11 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model, double clamp_min; double clamp_max; switch (op_type) { - case OperatorType::kTensorFlowMinimum: + case OperatorType::kMinimum: // Element-wise Minimum clamp_min = -std::numeric_limits::infinity(); clamp_max = clamp_value; break; - case OperatorType::kTensorFlowMaximum: + case OperatorType::kMaximum: // Element-wise Maximum clamp_min = clamp_value; clamp_max = std::numeric_limits::infinity(); break; @@ -72,8 +72,8 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model, bool RemoveTrivialQuantizedMinMax::Run(Model* model, std::size_t op_index) { const auto it = model->operators.begin() + op_index; auto* op = it->get(); - if ((op->type != OperatorType::kTensorFlowMinimum && - op->type != OperatorType::kTensorFlowMaximum) || + if ((op->type != OperatorType::kMinimum && + op->type != OperatorType::kMaximum) || op->inputs.size() != 2) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc index e28d8cf01eafee64e08ac2cc4b43ea7c227456c2..404f27e067402474484d3ee8e23595fb9f93a6c9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc @@ -30,7 +30,7 @@ namespace { bool IsReshapeTrivial(const Model& model, const Operator& op, RemoveTrivialReshape* transformation) { - CHECK(op.type == OperatorType::kTensorFlowReshape); + CHECK(op.type == OperatorType::kReshape); // One way in which a reshape can be trivial is if its // output shape is == its input shape @@ -58,7 +58,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, // is only consumed by another reshape. if (CountOpsWithInput(model, op.outputs[0]) == 1) { const auto* next_op = GetOpWithInput(model, op.outputs[0]); - if (next_op->type == OperatorType::kTensorFlowReshape) { + if (next_op->type == OperatorType::kReshape) { transformation->AddMessageF( "%s is trivial because its output is only consumed by another " "Reshape op %s", @@ -75,7 +75,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op, bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) { const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); - if (reshape_op->type != OperatorType::kTensorFlowReshape) { + if (reshape_op->type != OperatorType::kReshape) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc index 1956ab2d2021cda84a0d715534923d6174c30dd1..dde91234a8240f4518cd105c2cc4e79102735980 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc @@ -48,7 +48,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) { for (const auto& rnn_state : model->flags.rnn_states()) { if (output == rnn_state.state_array()) { CHECK(op->type == OperatorType::kFill || - op->type == OperatorType::kTensorFlowIdentity); + op->type == OperatorType::kIdentity); found_output_as_rnn_state_array = true; break; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc index 9f5b7920cb937b021eb23fc1d5fdc3c1ff18a72d..550de83018f25a7aa4da82707fedb86434615fb0 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc @@ -37,8 +37,8 @@ bool IsElementwiseOperator(OperatorType optype) { case OperatorType::kRelu1: case OperatorType::kRelu6: case OperatorType::kTanh: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: + case OperatorType::kSqrt: + case OperatorType::kSquare: return true; default: return false; @@ -51,7 +51,7 @@ bool IsMoveOperator(OperatorType optype) { case OperatorType::kExpandDims: case OperatorType::kSpaceToDepth: case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: + case OperatorType::kReshape: case OperatorType::kTranspose: return true; default: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc index 9e7fe1b1ccd851dd998e59e75ff798f52f7c6e5a..c907a597cb719b68dbf36868a75e49a7c5181423 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc @@ -123,8 +123,8 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) { } TensorFlowReshapeOperator* reshape_op = - ConvertOperator( - reshape_it->get(), OperatorType::kTensorFlowReshape); + ConvertOperator(reshape_it->get(), + OperatorType::kReshape); if (reshape_op == nullptr) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc index 6e78653fad238085da5ba66166884093ea9b0214..f7e5aa6609bd4f7eb2a95750125e30a7803b36e1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -145,17 +145,17 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model, outval = floor(val0 / val1); } else if (binary_op->type == OperatorType::kFloorMod) { outval = val0 - (floor(val0 / val1) * val1); - } else if (binary_op->type == OperatorType::kTensorFlowMinimum) { + } else if (binary_op->type == OperatorType::kMinimum) { outval = std::min(val0, val1); - } else if (binary_op->type == OperatorType::kTensorFlowMaximum) { + } else if (binary_op->type == OperatorType::kMaximum) { outval = std::max(val0, val1); - } else if (binary_op->type == OperatorType::kTensorFlowLess) { + } else if (binary_op->type == OperatorType::kLess) { outval = val0 < val1; - } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) { + } else if (binary_op->type == OperatorType::kLessEqual) { outval = val0 <= val1; - } else if (binary_op->type == OperatorType::kTensorFlowGreater) { + } else if (binary_op->type == OperatorType::kGreater) { outval = val0 > val1; - } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) { + } else if (binary_op->type == OperatorType::kGreaterEqual) { outval = val0 >= val1; } else { LOG(FATAL) << "should not get here"; @@ -198,12 +198,12 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { binary_op->type != OperatorType::kDiv && binary_op->type != OperatorType::kFloorDiv && binary_op->type != OperatorType::kFloorMod && - binary_op->type != OperatorType::kTensorFlowMinimum && - binary_op->type != OperatorType::kTensorFlowMaximum && - binary_op->type != OperatorType::kTensorFlowLess && - binary_op->type != OperatorType::kTensorFlowLessEqual && - binary_op->type != OperatorType::kTensorFlowGreater && - binary_op->type != OperatorType::kTensorFlowGreaterEqual) { + binary_op->type != OperatorType::kMinimum && + binary_op->type != OperatorType::kMaximum && + binary_op->type != OperatorType::kLess && + binary_op->type != OperatorType::kLessEqual && + binary_op->type != OperatorType::kGreater && + binary_op->type != OperatorType::kGreaterEqual) { return false; } CHECK_EQ(binary_op->inputs.size(), 2); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc index 7e7ad383e7789891f5396845241e70143dc8b76f..41562ab393694d76c5cb6c5df5f7df2a71f893f5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc @@ -25,7 +25,7 @@ namespace toco { bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); - if (base_op->type != OperatorType::kTensorFlowReshape) { + if (base_op->type != OperatorType::kReshape) { return false; } const auto* op = static_cast(base_op); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc index 9ea01acd05364224ce219bed533c999793a2a2f1..8a0e3e8995839a737b5671701a97b514b0fc7bf1 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc @@ -22,8 +22,7 @@ namespace toco { bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { const auto it = model->operators.begin() + op_index; const auto* op = it->get(); - if (!(op->type == OperatorType::kTensorFlowShape || - op->type == OperatorType::kRank)) { + if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) { return false; } @@ -48,7 +47,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) { // Compute the output CHECK(!output_array.buffer); auto& output_buffer = output_array.GetMutableBuffer(); - if (op->type == OperatorType::kTensorFlowShape) { + if (op->type == OperatorType::kShape) { // Copy the input shape into the output buffer. output_buffer.data = input_array.shape().dims(); } else if (op->type == OperatorType::kRank) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc index 69db1942cd52af810acf38a818997c71122d8500..a4d5f1923a1dffdff1ef51eb5317fa5794a8bc27 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc @@ -41,7 +41,7 @@ void Stack(Model* model, StackOperator const& op) { const auto& input_array = model->GetArray(op.inputs[i]); int input_size = RequiredBufferSizeForShape(input_array.shape()); memcpy(&output_data[dst_offset], &input_array.GetBuffer().data[0], - input_size * sizeof(Type)); + input_size * ElementSize(Type)); dst_offset += input_size; } CHECK_EQ(dst_offset, output_data.size()); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 1dd52e906900e997f282740404a81b9fcd21e867..6ee231465fae5127e3769bd6b9060ea60d59eb2c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -155,14 +155,7 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { break; } - // Erase input array if no longer used - if (IsDiscardableArray(*model, op->inputs[0]) && - CountOpsWithInput(*model, op->inputs[0]) == 1) { - model->EraseArray(op->inputs[0]); - } - - // Erase the operator - model->operators.erase(it); + DeleteOpAndArraysIfUnused(model, it->get()); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index f6c8f79d8d3311dc2294e3ec406a184b2a16a6b5..f89ef85fdb63ca4906c7f016e86bb1f9d8a7099a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -53,13 +53,13 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { case OperatorType::kCast: case OperatorType::kLog: case OperatorType::kNeg: - case OperatorType::kTensorFlowRsqrt: - case OperatorType::kTensorFlowSqrt: - case OperatorType::kTensorFlowSquare: - case OperatorType::kTensorFlowSum: - case OperatorType::kTensorFlowMin: - case OperatorType::kTensorFlowMax: - case OperatorType::kTensorFlowReshape: + case OperatorType::kRsqrt: + case OperatorType::kSqrt: + case OperatorType::kSquare: + case OperatorType::kSum: + case OperatorType::kMin: // Reduction Min + case OperatorType::kMax: // Reduction Max + case OperatorType::kReshape: case OperatorType::kRelu6: case OperatorType::kRelu1: case OperatorType::kRelu: @@ -103,7 +103,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { // The min-max is only copied for ops that copy data without arithmetic. // In future trivial transpose, etc, can be handled here. - if (unary_op->type == OperatorType::kTensorFlowReshape) { + if (unary_op->type == OperatorType::kReshape) { CopyMinMaxFromFirstInput(*unary_op, model); } @@ -164,10 +164,10 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = outval; } - } else if (unary_op->type == OperatorType::kTensorFlowReshape) { + } else if (unary_op->type == OperatorType::kReshape) { CHECK(input_buffer_size == output_buffer_size); output_float_data = *input_float_data; - } else if (unary_op->type == OperatorType::kTensorFlowSum) { + } else if (unary_op->type == OperatorType::kSum) { CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs"; if (!IsConstantParameterArray(*model, unary_op->inputs[1])) { AddMessageF("Axis input is non-constant"); @@ -196,7 +196,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } output_float_data[i] = sum; } - } else if (unary_op->type == OperatorType::kTensorFlowMin) { + } else if (unary_op->type == OperatorType::kMin) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { @@ -207,7 +207,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { min = std::min(min, (*input_float_data)[i]); } output_float_data[0] = min; - } else if (unary_op->type == OperatorType::kTensorFlowMax) { + } else if (unary_op->type == OperatorType::kMax) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. for (int i = 0; i < output_dims_count; i++) { @@ -220,9 +220,9 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { output_float_data[0] = max; } else if (unary_op->type == OperatorType::kNeg || unary_op->type == OperatorType::kLog || - unary_op->type == OperatorType::kTensorFlowRsqrt || - unary_op->type == OperatorType::kTensorFlowSqrt || - unary_op->type == OperatorType::kTensorFlowSquare) { + unary_op->type == OperatorType::kRsqrt || + unary_op->type == OperatorType::kSqrt || + unary_op->type == OperatorType::kSquare) { // Element-wise ops. Should have perfectly matching sizes here. for (int i = 0; i < output_dims_count; i++) { CHECK_EQ(output_shape.dims(i), input_shape.dims(i)); @@ -235,11 +235,11 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { outval = -val; } else if (unary_op->type == OperatorType::kLog) { outval = std::log(val); - } else if (unary_op->type == OperatorType::kTensorFlowRsqrt) { + } else if (unary_op->type == OperatorType::kRsqrt) { outval = 1.0f / std::sqrt(val); - } else if (unary_op->type == OperatorType::kTensorFlowSqrt) { + } else if (unary_op->type == OperatorType::kSqrt) { outval = std::sqrt(val); - } else if (unary_op->type == OperatorType::kTensorFlowSquare) { + } else if (unary_op->type == OperatorType::kSquare) { outval = val * val; } else { LOG(FATAL) << "should not get here."; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc index 2e063e35548aa5e51c3bcc94a2dfc7992180d014..b615c9a545695e5d14fa5809e0c38a770f23ea24 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc @@ -28,7 +28,7 @@ namespace toco { bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) { const auto reshape_it = model->operators.begin() + op_index; auto* reshape_op = reshape_it->get(); - if (reshape_op->type != OperatorType::kTensorFlowReshape) { + if (reshape_op->type != OperatorType::kReshape) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc index dd3e73635ae0215510f0a8d1aee487da5af35700..e8bb85704e1c750300079681b5a12f6a488b6b48 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc @@ -36,7 +36,7 @@ bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) { // If the output is consumed by a reshape op, it's a trivial squeeze. if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) { const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]); - if (next_op->type == OperatorType::kTensorFlowReshape) { + if (next_op->type == OperatorType::kReshape) { AddMessageF( "%s is trivial because its output is only consumed by a " "Reshape op", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc index 5c0c1e3478fa0d94104d1b76bab176b98b314c50..fa5ee899334bdf2d39a6861b0e0c4548142e9d2a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc @@ -28,8 +28,8 @@ namespace toco { bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { auto concat_it = model->operators.begin() + op_index; const auto* tf_concat_op = concat_it->get(); - if (tf_concat_op->type != OperatorType::kTensorFlowConcat && - tf_concat_op->type != OperatorType::kTensorFlowConcatV2) { + if (tf_concat_op->type != OperatorType::kConcat && + tf_concat_op->type != OperatorType::kConcatV2) { return false; } @@ -38,7 +38,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) { // of inputs: in Concat,the axis is the first input, while in // ConcatV2, it is the last input. std::size_t axis_pos = 0; - if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) { + if (tf_concat_op->type == OperatorType::kConcatV2) { axis_pos = tf_concat_op->inputs.size() - 1; } const string axis_name = tf_concat_op->inputs[axis_pos]; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc index 2a236d3f98784e8244942f94d5a250b5bc00a8ad..d496f5ae5eeeca5063e23b25498b0ac450e9f946 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc @@ -26,7 +26,7 @@ namespace toco { bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { auto matmul_it = model->operators.begin() + op_index; - if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) { + if (matmul_it->get()->type != OperatorType::kMatMul) { return false; } const auto* matmul_op = @@ -97,7 +97,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) { // MatMul op as a FullyConnected. However, TensorFlow skips the Reshape ops if // the input doesn't need reshaping, so we can't just match (Reshape, MatMul) // pairs. - if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) { + if (previous_op && previous_op->type == OperatorType::kReshape) { AddMessageF("Combining %s and %s into %s", LogName(*previous_op), LogName(*matmul_op), LogName(*fc_op)); const auto& previous_op_output = previous_op->outputs[0]; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc index 38e0005890ac10410df4ddb5290be8fcc948c349..4edffe3d48fd880c0261b34fc407b8e2ac66ccb9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc @@ -27,7 +27,7 @@ namespace toco { bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) { const auto merge_it = model->operators.begin() + op_index; const auto* merge_op = merge_it->get(); - if (merge_op->type != OperatorType::kTensorFlowMerge) { + if (merge_op->type != OperatorType::kMerge) { return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc index a418073441f1241a5acb1164b36f332828ea2e99..da8e7a2d1c06cf89b9708b404da7667565245f8f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc @@ -27,7 +27,7 @@ namespace toco { bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { const auto switch_it = model->operators.begin() + op_index; const auto* switch_op = switch_it->get(); - if (switch_op->type != OperatorType::kTensorFlowSwitch) { + if (switch_op->type != OperatorType::kSwitch) { return false; } @@ -92,7 +92,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) { if (*input_it == switch_op->outputs[nonselected_output_index]) { // Let us guard our assumption that only Merge nodes consume the outputs // of Switch nodes: - CHECK(other_op->type == OperatorType::kTensorFlowMerge); + CHECK(other_op->type == OperatorType::kMerge); input_it = other_op->inputs.erase(input_it); } else { ++input_it; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc deleted file mode 100644 index 1ddf54c778cd1fae7a8fce0ecb97209274e71ac0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include -#include -#include - -#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" -#include "tensorflow/contrib/lite/toco/model.h" -#include "tensorflow/contrib/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" - -namespace toco { - -namespace { - -void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op, - int operand_index) { - CHECK(tile_op->type == OperatorType::kTensorFlowTile); - CHECK_EQ(binary_op->inputs.size(), 2); - CHECK_EQ(tile_op->inputs.size(), 2); - const string tile_multiplier_array = tile_op->inputs[1]; - const string tile_output_array = tile_op->outputs[0]; - binary_op->inputs[operand_index] = tile_op->inputs[0]; - auto tile_it = model->operators.begin(); - for (; tile_it != model->operators.end(); ++tile_it) { - if (tile_it->get() == tile_op) { - break; - } - } - CHECK(tile_it != model->operators.end()); - CHECK(tile_it->get() == tile_op); - model->operators.erase(tile_it); - if (!CountOpsWithInput(*model, tile_multiplier_array) && - !GetOpWithOutput(*model, tile_multiplier_array)) { - model->EraseArray(tile_multiplier_array); - } - if (!CountOpsWithInput(*model, tile_output_array)) { - model->EraseArray(tile_output_array); - } -} -} // namespace - -bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) { - const auto binary_it = model->operators.begin() + op_index; - auto* binary_op = binary_it->get(); - // Test for binary ops of types that we know how to resolve - if (binary_op->inputs.size() != 2) { - return false; - } - if (binary_op->type != OperatorType::kAdd && - binary_op->type != OperatorType::kMul && - binary_op->type != OperatorType::kSub && - binary_op->type != OperatorType::kDiv) { - return false; - } - - Operator* const op[2] = { - GetOpWithOutput(*model, binary_op->inputs[0]), - GetOpWithOutput(*model, binary_op->inputs[1]), - }; - - // In the unlikely case where both operands are Tile, we can't infer the - // output - // size without the Tile nodes, so we have to bail out. - if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] && - op[1]->type == OperatorType::kTensorFlowTile) { - return false; - } - - for (int i = 0; i < 2; i++) { - if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) { - // We can only remove a Tile operator is no other op than the present - // binary op was consuming its tiled output. - if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) { - AddMessageF("Removing %s", LogName(*op[i])); - RemoveTileOperator(model, op[i], binary_op, i); - return true; - } - } - } - return false; -} - -} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index fa2e85440e12fe0aca46f18079ba697d654b477f..8da33e8a2278757dd53180b9347259278c447669 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" #include "tensorflow/contrib/lite/toco/tensorflow_util.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" @@ -44,6 +43,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -63,8 +63,6 @@ using tensorflow::TensorShapeProto; namespace toco { -using port::Status; - namespace { bool HasAttr(const NodeDef& node, const string& attr_name) { return node.attr().count(attr_name) > 0; @@ -130,6 +128,42 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node, return attr.list(); } +tensorflow::Status CheckOptionalAttr(const NodeDef& node, + const string& attr_name, + const string& expected_value) { + if (HasAttr(node, attr_name)) { + const string& value = GetStringAttr(node, attr_name); + if (value != expected_value) { + return tensorflow::errors::InvalidArgument( + "Unexpected value for attribute '" + attr_name + "'. Expected '" + + expected_value + "'"); + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status CheckOptionalAttr( + const NodeDef& node, const string& attr_name, + const tensorflow::DataType& expected_value) { + if (HasAttr(node, attr_name)) { + const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name); + if (value != expected_value) { + return tensorflow::errors::InvalidArgument( + "Unexpected value for attribute '" + attr_name + "'. Expected '" + + tensorflow::DataType_Name(expected_value) + "'"); + } + } + return tensorflow::Status::OK(); +} + +template +tensorflow::Status ExpectValue(const T1& v1, const T2& v2, + const string& description) { + if (v1 == v2) return tensorflow::Status::OK(); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Unexpected ", description, ": got ", v1, ", expected ", v2)); +} + ArrayDataType ConvertDataType(tensorflow::DataType dtype) { if (dtype == DT_UINT8) return ArrayDataType::kUint8; @@ -148,9 +182,10 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) { return ArrayDataType::kNone; } -Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< - tensorflow::TensorShapeProto_Dim>& input_dims, - int* input_flat_size, Shape* shape) { +tensorflow::Status ImportShape( + const TFLITE_PROTO_NS::RepeatedPtrField& + input_dims, + int* input_flat_size, Shape* shape) { std::vector input_dims_only_sizes; for (auto& d : input_dims) { if (d.size() == 0) { @@ -160,23 +195,24 @@ Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< // For now, tweaking this to record a 0-D shape instead. shape->mutable_dims()->clear(); if (input_flat_size != nullptr) *input_flat_size = 0; - return Status::OK(); + return tensorflow::Status::OK(); } // TensorFlow's shapes use int64s, while TOCO uses ints. if (d.size() > std::numeric_limits::max()) { - return Status(false, "Shape element overflows"); + return tensorflow::errors::InvalidArgument("Shape element overflows"); } input_dims_only_sizes.push_back(d.size()); } *shape->mutable_dims() = input_dims_only_sizes; - if (input_flat_size == nullptr) return Status::OK(); + if (input_flat_size == nullptr) return tensorflow::Status::OK(); return NumElements(input_dims_only_sizes, input_flat_size); } -Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -203,18 +239,18 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_float_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", input_tensor.tensor_content().size() / sizeof(float), ") nor float_val (", input_tensor.float_val_size(), ") have the right dimensions (", input_flat_size, ") for this float tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -236,18 +272,18 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", input_tensor.tensor_content().size() / sizeof(uint8_t), ") nor int_val (", input_tensor.int_val_size(), ") have the right dimensions (", input_flat_size, ") for this uint8 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -269,18 +305,17 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status( - false, - absl::StrCat("Neither input_content (", - input_tensor.tensor_content().size() / sizeof(int32), - ") nor int_val (", input_tensor.int_val_size(), - ") have the right dimensions (", input_flat_size, - ") for this int32 tensor")); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Neither input_content (", + input_tensor.tensor_content().size() / sizeof(int32), ") nor int_val (", + input_tensor.int_val_size(), ") have the right dimensions (", + input_flat_size, ") for this int32 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -302,18 +337,18 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", input_tensor.tensor_content().size() / sizeof(int64), ") nor int64_val (", input_tensor.int64_val_size(), ") have the right dimensions (", input_flat_size, ") for this int64 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -343,19 +378,19 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { // So far only encountered that in an array with 1 entry, let's // require that until we encounter a graph where that's not the case. if (output_bool_data.size() != 1) { - return Status( - false, absl::StrCat("Neither input_content (", - input_tensor.tensor_content().size(), - ") nor bool_val (", input_tensor.bool_val_size(), - ") have the right dimensions (", input_flat_size, - ") for this bool tensor")); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Neither input_content (", input_tensor.tensor_content().size(), + ") nor bool_val (", input_tensor.bool_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this bool tensor")); } output_bool_data[0] = false; } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportStringArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -365,9 +400,9 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { if (!status.ok()) return status; if (input_flat_size != input_tensor.string_val_size()) { - return Status(false, - "Input_content string_val doesn't have the right dimensions " - "for this string tensor"); + return tensorflow::errors::InvalidArgument( + "Input_content string_val doesn't have the right dimensions " + "for this string tensor"); } auto& output_string_data = @@ -377,7 +412,7 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { for (int i = 0; i < input_flat_size; ++i) { output_string_data[i] = input_tensor.string_val(i); } - return Status::OK(); + return tensorflow::Status::OK(); } // Count the number of inputs of a given node. If @@ -391,18 +426,19 @@ int GetInputsCount(const NodeDef& node, return i; } } - return node.input_size(); - } else { - return node.input_size(); } + return node.input_size(); } -void CheckInputsCount(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - int expected_input_count) { - QCHECK_EQ(GetInputsCount(node, tf_import_flags), expected_input_count) - << node.op() << " node expects " << expected_input_count - << " input(s) other than control dependencies: " << node.DebugString(); +tensorflow::Status CheckInputsCount( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + int expected_input_count) { + if (GetInputsCount(node, tf_import_flags) != expected_input_count) { + return tensorflow::errors::FailedPrecondition( + node.op(), " node expects ", expected_input_count, + " input(s) other than control dependencies: ", node.DebugString()); + } + return tensorflow::Status::OK(); } template @@ -417,14 +453,14 @@ string CreateConstArray(Model* model, string const& name, return array_name; } -Status ConvertConstOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConstOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Const"); const auto& tensor = GetTensorAttr(node, "value"); const auto dtype = GetDataTypeAttr(node, "dtype"); - Status status = Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); auto& array = model->GetOrCreateArray(node.name()); switch (dtype) { @@ -460,24 +496,21 @@ Status ConvertConstOperator(const NodeDef& node, array.GetMutableBuffer(); break; } - if (!status.ok()) { - status.AppendMessage(" (while processing node '" + node.name() + "')"); - } - return status; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + status, " (while processing node '" + node.name() + "')"); + return tensorflow::Status::OK(); } -void ConvertConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Conv2D"); - CheckInputsCount(node, tf_import_flags, 2); + TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2)); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. - if (HasAttr(node, "data_format")) { - CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); - } - CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC")); + TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT)); const auto& input_name = node.input(0); const auto& weights_name = node.input(1); @@ -502,27 +535,26 @@ void ConvertConvOperator(const NodeDef& node, auto* conv = new ConvOperator; conv->inputs = {input_name, reordered_weights_name}; conv->outputs = {node.name()}; + if (!HasAttr(node, "strides")) { + return tensorflow::errors::InvalidArgument("Missing attribute 'strides'"); + } const auto& strides = GetListAttr(node, "strides"); - CHECK_EQ(strides.i_size(), 4); - CHECK_EQ(strides.i(0), 1); - CHECK_EQ(strides.i(3), 1); + TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)")); conv->stride_height = strides.i(1); conv->stride_width = strides.i(2); if (HasAttr(node, "dilations")) { const auto& dilations = GetListAttr(node, "dilations"); - CHECK_EQ(dilations.i_size(), 4); - CHECK_EQ(dilations.i(0), 1) - << "Can only import Conv ops with dilation along the height (1st) or " - "width (2nd) axis. TensorFlow op \"" - << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " - << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) - << "]."; - CHECK_EQ(dilations.i(3), 1) - << "Can only import Conv ops with dilation along the height (1st) or " - "width (2nd) axis. TensorFlow op \"" - << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " - << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) - << "]."; + TF_RETURN_IF_ERROR( + ExpectValue(dilations.i_size(), 4, "number of dilations")); + if (dilations.i(0) != 1 || dilations.i(3) != 1) { + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Can only import Conv ops with dilation along the height " + "(1st) or width (2nd) axis. TensorFlow op \"", + node.name(), "\" had dilations:[ ", dilations.i(0), ", ", + dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "].")); + } conv->dilation_height_factor = dilations.i(1); conv->dilation_width_factor = dilations.i(2); } else { @@ -535,16 +567,19 @@ void ConvertConvOperator(const NodeDef& node, } else if (padding == "VALID") { conv->padding.type = PaddingType::kValid; } else { - LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + return tensorflow::errors::InvalidArgument( + "Bad padding (only SAME and VALID are supported)"); } model->operators.emplace_back(conv); + + return tensorflow::Status::OK(); } -void ConvertDepthwiseConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDepthwiseConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "DepthwiseConv2dNative"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -591,13 +626,14 @@ void ConvertDepthwiseConvOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(conv); + return tensorflow::Status::OK(); } -void ConvertDepthToSpaceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDepthToSpaceOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "DepthToSpace"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); auto* op = new DepthToSpaceOperator; @@ -606,28 +642,37 @@ void ConvertDepthToSpaceOperator(const NodeDef& node, op->block_size = GetIntAttr(node, "block_size"); QCHECK_GE(op->block_size, 2); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSpaceToDepthOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSpaceToDepthOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SpaceToDepth"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); - CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + tensorflow::DataType dtype = GetDataTypeAttr(node, "T"); + if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 && + dtype != DT_INT64) { + const auto* enum_descriptor = tensorflow::DataType_descriptor(); + LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:" + << enum_descriptor->FindValueByNumber(dtype)->name() << ". " + << "T must be one of {DT_FLOAT, DT_INT8, DT_INT32, DT_INT64}."; + } auto* op = new SpaceToDepthOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); op->block_size = GetIntAttr(node, "block_size"); QCHECK_GE(op->block_size, 2); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBiasAddOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertBiasAddOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "BiasAdd"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto& input_name = node.input(0); const auto& bias_name = node.input(1); @@ -637,13 +682,14 @@ void ConvertBiasAddOperator(const NodeDef& node, biasadd->inputs.push_back(bias_name); biasadd->outputs.push_back(node.name()); model->operators.emplace_back(biasadd); + return tensorflow::Status::OK(); } -void ConvertRandomUniform(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertRandomUniform( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "RandomUniform"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32); auto op = absl::make_unique(); @@ -654,11 +700,12 @@ void ConvertRandomUniform(const NodeDef& node, op->seed2 = GetIntAttr(node, "seed2"); CHECK(model != nullptr); model->operators.emplace_back(std::move(op)); + return tensorflow::Status::OK(); } -void ConvertIdentityOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertIdentityOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" || node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient"); auto* op = new TensorFlowIdentityOperator; @@ -675,13 +722,14 @@ void ConvertIdentityOperator(const NodeDef& node, op->inputs.push_back(input_name); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFakeQuantWithMinMaxArgs( +tensorflow::Status ConvertFakeQuantWithMinMaxArgs( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); auto* op = new FakeQuantOperator; op->inputs.push_back(node.input(0)); op->minmax.reset(new MinMax); @@ -692,9 +740,10 @@ void ConvertFakeQuantWithMinMaxArgs( // tf.fake_quant_with_min_max_args num_bits defaults to 8. op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFakeQuantWithMinMaxVars( +tensorflow::Status ConvertFakeQuantWithMinMaxVars( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars"); @@ -710,14 +759,14 @@ void ConvertFakeQuantWithMinMaxVars( op->outputs.push_back(node.name()); op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } - -void ConvertSqueezeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSqueezeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Squeeze"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); auto* op = new SqueezeOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); @@ -731,13 +780,14 @@ void ConvertSqueezeOperator(const NodeDef& node, } model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSumOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Sum"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSumOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -746,13 +796,14 @@ void ConvertSumOperator(const NodeDef& node, if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertSplitOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSplitOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Split"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSplitOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -763,13 +814,14 @@ void ConvertSplitOperator(const NodeDef& node, } op->num_split = num_split; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSwitchOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSwitchOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Switch"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowSwitchOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -777,13 +829,14 @@ void ConvertSwitchOperator(const NodeDef& node, // Switch operators have two outputs: "name" and "name:1". op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSoftmaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSoftmaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Softmax"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); auto* softmax = new SoftmaxOperator; softmax->inputs.push_back(input_name); @@ -792,13 +845,14 @@ void ConvertSoftmaxOperator(const NodeDef& node, CHECK(!node.attr().count("beta")); // Stab in the dark, just in case. softmax->beta = 1.f; model->operators.emplace_back(softmax); + return tensorflow::Status::OK(); } -void ConvertLRNOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertLRNOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "LRN"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); auto* lrn = new LocalResponseNormalizationOperator; lrn->inputs.push_back(input_name); @@ -808,13 +862,14 @@ void ConvertLRNOperator(const NodeDef& node, lrn->alpha = GetFloatAttr(node, "alpha"); lrn->beta = GetFloatAttr(node, "beta"); model->operators.emplace_back(lrn); + return tensorflow::Status::OK(); } -void ConvertMaxPoolOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMaxPoolOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "MaxPool"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -850,13 +905,14 @@ void ConvertMaxPoolOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(maxpool); + return tensorflow::Status::OK(); } -void ConvertAvgPoolOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertAvgPoolOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "AvgPool"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto& input_name = node.input(0); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. @@ -888,13 +944,13 @@ void ConvertAvgPoolOperator(const NodeDef& node, LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; } model->operators.emplace_back(avgpool); + return tensorflow::Status::OK(); } - -void ConvertBatchMatMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 2); +tensorflow::Status ConvertBatchMatMulOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false)); @@ -904,12 +960,13 @@ void ConvertBatchMatMulOperator(const NodeDef& node, batch_matmul->inputs = {node.input(0), node.input(1)}; batch_matmul->outputs = {node.name()}; model->operators.emplace_back(batch_matmul); + return tensorflow::Status::OK(); } -void ConvertMatMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 2); +tensorflow::Status ConvertMatMulOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); // Transpose flags should be easy to support, but we don't have a // GraphDef with them to test on at the moment. @@ -926,11 +983,12 @@ void ConvertMatMulOperator(const NodeDef& node, matmul->inputs = {node.input(0), node.input(1)}; matmul->outputs = {node.name()}; model->operators.emplace_back(matmul); + return tensorflow::Status::OK(); } -void ConvertConcatOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConcatOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { Operator* op = nullptr; if (node.op() == "Concat") { op = new TensorFlowConcatOperator; @@ -950,13 +1008,14 @@ void ConvertConcatOperator(const NodeDef& node, } op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // This method supports simple operators without additional attributes. template -void ConvertSimpleOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSimpleOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { auto* op = new Op; const int num_inputs = GetInputsCount(node, tf_import_flags); for (int i = 0; i < num_inputs; ++i) { @@ -964,22 +1023,23 @@ void ConvertSimpleOperator(const NodeDef& node, } op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // This method supports simple operators without additional attributes. template -void ConvertSimpleOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, NumInputs); - ConvertSimpleOperator(node, tf_import_flags, model); +tensorflow::Status ConvertSimpleOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs)); + return ConvertSimpleOperator(node, tf_import_flags, model); } -void ConvertMaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Max"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowMaxOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -988,13 +1048,14 @@ void ConvertMaxOperator(const NodeDef& node, if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertMinOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMinOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Min"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new TensorFlowMinOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1003,12 +1064,12 @@ void ConvertMinOperator(const NodeDef& node, if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } - -void ConvertUnsupportedOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertUnsupportedOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { LOG(INFO) << "Converting unsupported operation: " << node.op(); auto* op = new TensorFlowUnsupportedOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1031,15 +1092,16 @@ void ConvertUnsupportedOperator(const NodeDef& node, const auto& output_type = GetDataTypeAttr(node, "Tout"); op->output_data_types.push_back(ConvertDataType(output_type)); } + return tensorflow::Status::OK(); } -void ConvertStridedSliceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertStridedSliceOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "StridedSlice"); // TODO(soroosh): The 4th input (strides) should be e optional, to be // consistent with TF. - CheckInputsCount(node, tf_import_flags, 4); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); auto* op = new StridedSliceOperator; for (const auto& input : node.input()) { @@ -1059,14 +1121,15 @@ void ConvertStridedSliceOperator(const NodeDef& node, : 0; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertPlaceholderOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertPlaceholderOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput"); if (node.op() == "Placeholder") { - CheckInputsCount(node, tf_import_flags, 0); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0)); } auto& array = model->GetOrCreateArray(node.name()); if (node.attr().count("dtype")) { @@ -1091,17 +1154,20 @@ void ConvertPlaceholderOperator(const NodeDef& node, } } } + return tensorflow::Status::OK(); } -void ConvertNoOpOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) {} +tensorflow::Status ConvertNoOpOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + return tensorflow::Status::OK(); +} -void ConvertCastOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertCastOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Cast"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT"); const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT"); auto* op = new CastOperator; @@ -1110,27 +1176,31 @@ void ConvertCastOperator(const NodeDef& node, op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFloorOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertFloorOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Floor"); - CheckInputsCount(node, tf_import_flags, 1); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); const auto data_type = GetDataTypeAttr(node, "T"); CHECK(data_type == DT_FLOAT); auto* op = new FloorOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertGatherOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertGatherOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK(node.op() == "Gather" || node.op() == "GatherV2"); - if (node.op() == "Gather") CheckInputsCount(node, tf_import_flags, 2); - if (node.op() == "GatherV2") CheckInputsCount(node, tf_import_flags, 3); + if (node.op() == "Gather") + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + if (node.op() == "GatherV2") + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); const auto indices_data_type = GetDataTypeAttr(node, "Tindices"); CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64); auto* op = new GatherOperator; @@ -1140,13 +1210,14 @@ void ConvertGatherOperator(const NodeDef& node, // should read it an pass it on to the TF Lite Interpreter. op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertArgMaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertArgMaxOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "ArgMax"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); const auto axis_data_type = HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; const auto output_type = HasAttr(node, "output_type") @@ -1160,13 +1231,14 @@ void ConvertArgMaxOperator(const NodeDef& node, op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertResizeBilinearOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertResizeBilinearOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "ResizeBilinear"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new ResizeBilinearOperator; op->align_corners = false; @@ -1178,13 +1250,14 @@ void ConvertResizeBilinearOperator(const NodeDef& node, op->inputs.push_back(node.input(1)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBatchNormWithGlobalNormalizationOperator( +tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization"); - CheckInputsCount(node, tf_import_flags, 5); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); // TODO(ahentz): to really match tensorflow we need to add variance_epsilon // to the input, before feeding it into TensorFlowRsqrtOperator. @@ -1227,13 +1300,14 @@ void ConvertBatchNormWithGlobalNormalizationOperator( op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertFusedBatchNormOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertFusedBatchNormOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "FusedBatchNorm"); - CheckInputsCount(node, tf_import_flags, 5); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5)); // Declare shortcuts for the inputs. const string& gamma_input = node.input(1); @@ -1279,13 +1353,14 @@ void ConvertFusedBatchNormOperator(const NodeDef& node, op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertSpaceToBatchNDOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSpaceToBatchNDOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SpaceToBatchND"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32); auto* op = new SpaceToBatchNDOperator; @@ -1294,13 +1369,14 @@ void ConvertSpaceToBatchNDOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertBatchToSpaceNDOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertBatchToSpaceNDOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "BatchToSpaceND"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32); CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32); auto* op = new BatchToSpaceNDOperator; @@ -1309,13 +1385,14 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertMeanOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertMeanOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Mean"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); auto* op = new MeanOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1326,11 +1403,12 @@ void ConvertMeanOperator(const NodeDef& node, } else if (HasAttr(node, "keep_dims")) { op->keep_dims = GetBoolAttr(node, "keep_dims"); } + return tensorflow::Status::OK(); } -void ConvertSvdfOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSvdfOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Svdf"); const int input_size = GetInputsCount(node, tf_import_flags); QCHECK(input_size == 3 || input_size == 4) @@ -1353,14 +1431,15 @@ void ConvertSvdfOperator(const NodeDef& node, } op->rank = node.attr().at("Rank").i(); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } // This is just bare bones support to get the shapes to propagate. -void ConvertTransposeConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertTransposeConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Conv2DBackpropInput"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new TransposeConvOperator; op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); @@ -1401,11 +1480,13 @@ void ConvertTransposeConvOperator(const NodeDef& node, if (existing_transpose) { CHECK(existing_transpose->type == OperatorType::kTranspose); } else { - // Transpose weights from HWIO order to OHWI order, which is more efficient - // for computation + // Transpose weights from HWOI order to OHWI order, which is more efficient + // for computation. (Note that TensorFlow considers the order as HWIO + // because they consider this a backward conv, inverting the sense of + // input/output.) TransposeOperator* transpose = new TransposeOperator; string perm_array = CreateConstArray( - model, node.name() + "_transpose_perm", {3, 0, 1, 2}); + model, node.name() + "_transpose_perm", {2, 0, 1, 3}); transpose->inputs = {weights_name, perm_array}; transpose->outputs = {transposed_weights_name}; model->operators.emplace_back(transpose); @@ -1422,14 +1503,14 @@ void ConvertTransposeConvOperator(const NodeDef& node, "Conv2DBackpropInput nodes."; } model->operators.emplace_back(op); + return tensorflow::Status::OK(); } - -void ConvertRangeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertRangeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Range"); - CheckInputsCount(node, tf_import_flags, 3); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3)); auto* op = new RangeOperator; if (HasAttr(node, "Tidx")) { const auto dtype = toco::GetDataTypeAttr(node, "Tidx"); @@ -1442,11 +1523,12 @@ void ConvertRangeOperator(const NodeDef& node, op->inputs.push_back(node.input(2)); op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } -void ConvertStackOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertStackOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK((node.op() == "Stack") || (node.op() == "Pack")); auto* op = new StackOperator; const int num_inputs = GetInputsCount(node, tf_import_flags); @@ -1462,9 +1544,9 @@ void ConvertStackOperator(const NodeDef& node, op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0; op->outputs.push_back(node.name()); model->operators.emplace_back(op); + return tensorflow::Status::OK(); } - // Some TensorFlow ops only occur in graph cycles, representing // control flow. We do not currently support control flow, so we wouldn't // be able to fully support such graphs, including performing inference, @@ -1475,7 +1557,7 @@ void ConvertStackOperator(const NodeDef& node, // such ops as RNN back-edges, which is technically incorrect (does not // allow representing the op's semantics) but good enough to get a // graph visualization. -void ConvertOperatorSpecialCasedAsRNNBackEdge( +tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { // At the moment, the only type of operator special-cased in this way is @@ -1488,6 +1570,23 @@ void ConvertOperatorSpecialCasedAsRNNBackEdge( rnn_state->set_discardable(true); rnn_state->set_state_array(node.name()); rnn_state->set_back_edge_source_array(node.input(0)); + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertShapeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Shape"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); + const auto out_type = + HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32; + CHECK(out_type == DT_INT64 || out_type == DT_INT32); + auto op = absl::make_unique(); + op->output_data_type = ConvertDataType(out_type); + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); } void StripCaretFromArrayNames(Model* model) { @@ -1630,9 +1729,9 @@ bool InlineAllFunctions(GraphDef* graphdef) { return graph_modified; } -void ConvertTopKV2Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertTopKV2Operator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK((node.op() == "TopK") || (node.op() == "TopKV2")); auto op = absl::make_unique(); op->inputs.push_back(node.input(0)); @@ -1642,22 +1741,23 @@ void ConvertTopKV2Operator(const NodeDef& node, model, node.name() + "k", {static_cast(GetIntAttr(node, "k"))}); op->inputs.push_back(k_array); } else { - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); op->inputs.push_back(node.input(1)); } // The op has two outputs. op->outputs.push_back(node.name()); op->outputs.push_back(node.name() + ":1"); model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertDynamicPartitionOperator( +tensorflow::Status ConvertDynamicPartitionOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { auto op = absl::make_unique(); CHECK(HasAttr(node, "num_partitions")); op->num_partitions = GetIntAttr(node, "num_partitions"); - CheckInputsCount(node, tf_import_flags, 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); op->inputs.push_back(node.input(0)); op->inputs.push_back(node.input(1)); CHECK_GT(op->num_partitions, 1); @@ -1666,11 +1766,12 @@ void ConvertDynamicPartitionOperator( op->outputs.push_back(node.name() + ":" + std::to_string(i)); } model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertDynamicStitchOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertDynamicStitchOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { // The parallel and non-parallel variants are the same besides whether they // have a parallel loop; there are no behavioral differences. CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch"); @@ -1678,19 +1779,20 @@ void ConvertDynamicStitchOperator(const NodeDef& node, CHECK(HasAttr(node, "N")); op->num_partitions = GetIntAttr(node, "N"); // Expect all ID partitions + all value partitions. - CheckInputsCount(node, tf_import_flags, op->num_partitions * 2); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2)); for (int i = 0; i < op->num_partitions * 2; ++i) { op->inputs.push_back(node.input(i)); } op->outputs.push_back(node.name()); model->operators.emplace_back(op.release()); + return tensorflow::Status::OK(); } -void ConvertSparseToDenseOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertSparseToDenseOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "SparseToDense"); - CheckInputsCount(node, tf_import_flags, 4); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4)); auto* op = new SparseToDenseOperator; for (const string& input : node.input()) { @@ -1702,209 +1804,132 @@ void ConvertSparseToDenseOperator(const NodeDef& node, ? GetBoolAttr(node, "validate_indices") : true; model->operators.emplace_back(op); + return tensorflow::Status::OK(); } } // namespace namespace internal { -Status ImportTensorFlowNode(const tensorflow::NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - // TODO(ahentz): Historically these functions all CHECK-fail on error. We've - // been slowly converting them to return Status. - if (node.op() == "Const") { - return ConvertConstOperator(node, tf_import_flags, model); - } else if (node.op() == "Conv2D") { - ConvertConvOperator(node, tf_import_flags, model); - } else if (node.op() == "Conv2DBackpropInput") { - ConvertTransposeConvOperator(node, tf_import_flags, model); - } else if (node.op() == "DepthwiseConv2dNative") { - ConvertDepthwiseConvOperator(node, tf_import_flags, model); - } else if (node.op() == "DepthToSpace") { - ConvertDepthToSpaceOperator(node, tf_import_flags, model); - } else if (node.op() == "SpaceToDepth") { - ConvertSpaceToDepthOperator(node, tf_import_flags, model); - } else if (node.op() == "BiasAdd") { - ConvertBiasAddOperator(node, tf_import_flags, model); - } else if (node.op() == "Relu") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Relu6") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Sigmoid") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Tanh") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "MaxPool") { - ConvertMaxPoolOperator(node, tf_import_flags, model); - } else if (node.op() == "AvgPool") { - ConvertAvgPoolOperator(node, tf_import_flags, model); - } else if (node.op() == "Reshape") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "BatchMatMul") { - ConvertBatchMatMulOperator(node, tf_import_flags, model); - } else if (node.op() == "MatMul") { - ConvertMatMulOperator(node, tf_import_flags, model); - } else if (node.op() == "Div" || node.op() == "RealDiv") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Identity" || node.op() == "CheckNumerics" || - node.op() == "StopGradient") { - ConvertIdentityOperator(node, tf_import_flags, model); - } else if (node.op() == "FakeQuantWithMinMaxVars") { - ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model); - } else if (node.op() == "FakeQuantWithMinMaxArgs") { - ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model); - } else if (node.op() == "Neg") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Rsqrt") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Squeeze") { - ConvertSqueezeOperator(node, tf_import_flags, model); - } else if (node.op() == "Sqrt") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Square") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Add") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "AddN") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Mul") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Sub") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Sum") { - ConvertSumOperator(node, tf_import_flags, model); - } else if (node.op() == "Tile") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Concat" || node.op() == "ConcatV2") { - ConvertConcatOperator(node, tf_import_flags, model); - } else if (node.op() == "LRN") { - ConvertLRNOperator(node, tf_import_flags, model); - } else if (node.op() == "Softmax") { - ConvertSoftmaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Log") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "LogSoftmax") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "All") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Assert") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Less") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "LessEqual") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Greater") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "GreaterEqual") { - ConvertSimpleOperator( - node, tf_import_flags, model); - } else if (node.op() == "Max") { - ConvertMaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Min") { - ConvertMinOperator(node, tf_import_flags, model); - } else if (node.op() == "Maximum") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Minimum") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Merge") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Pad") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "PadV2") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "StridedSlice") { - ConvertStridedSliceOperator(node, tf_import_flags, model); - } else if (node.op() == "Shape") { - ConvertSimpleOperator(node, tf_import_flags, - model); - } else if (node.op() == "Slice") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Split") { - ConvertSplitOperator(node, tf_import_flags, model); - } else if (node.op() == "Switch") { - ConvertSwitchOperator(node, tf_import_flags, model); - } else if (node.op() == "Placeholder") { - ConvertPlaceholderOperator(node, tf_import_flags, model); - } else if (node.op() == "PlaceholderWithDefault") { - ConvertIdentityOperator(node, tf_import_flags, model); - } else if (node.op() == "LegacyFedInput") { - ConvertPlaceholderOperator(node, tf_import_flags, model); - } else if (node.op() == "NoOp") { - ConvertNoOpOperator(node, tf_import_flags, model); - } else if (node.op() == "Cast") { - ConvertCastOperator(node, tf_import_flags, model); - } else if (node.op() == "Floor") { - ConvertFloorOperator(node, tf_import_flags, model); - } else if (node.op() == "Gather" || node.op() == "GatherV2") { - ConvertGatherOperator(node, tf_import_flags, model); - } else if (node.op() == "ResizeBilinear") { - ConvertResizeBilinearOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchNormWithGlobalNormalization") { - ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags, - model); - } else if (node.op() == "FusedBatchNorm") { - ConvertFusedBatchNormOperator(node, tf_import_flags, model); - } else if (node.op() == "SpaceToBatchND") { - ConvertSpaceToBatchNDOperator(node, tf_import_flags, model); - } else if (node.op() == "BatchToSpaceND") { - ConvertBatchToSpaceNDOperator(node, tf_import_flags, model); - } else if (node.op() == "Mean") { - ConvertMeanOperator(node, tf_import_flags, model); - } else if (node.op() == "Svdf") { - ConvertSvdfOperator(node, tf_import_flags, model); - } else if (node.op() == "NextIteration") { - ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model); - } else if (node.op() == "ExpandDims") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Fill") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "FloorDiv") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "FloorMod") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Range") { - ConvertRangeOperator(node, tf_import_flags, model); - } else if (node.op() == "Rank") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Stack" || node.op() == "Pack") { - ConvertStackOperator(node, tf_import_flags, model); - } else if (node.op() == "Transpose") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "ArgMax") { - ConvertArgMaxOperator(node, tf_import_flags, model); - } else if (node.op() == "Exp") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "TopK" || node.op() == "TopKV2") { - ConvertTopKV2Operator(node, tf_import_flags, model); - } else if (node.op() == "DynamicPartition") { - ConvertDynamicPartitionOperator(node, tf_import_flags, model); - } else if (node.op() == "DynamicStitch" || - node.op() == "ParallelDynamicStitch") { - ConvertDynamicStitchOperator(node, tf_import_flags, model); - } else if (node.op() == "RandomUniform") { - ConvertRandomUniform(node, tf_import_flags, model); - } else if (node.op() == "Sin") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "Select") { - ConvertSimpleOperator(node, tf_import_flags, model); - } else if (node.op() == "SparseToDense") { - ConvertSparseToDenseOperator(node, tf_import_flags, model); + +using ConverterType = tensorflow::Status (*)( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model); +using ConverterMapType = std::unordered_map; + +ConverterMapType GetTensorFlowNodeConverterMap() { + return std::unordered_map({ + {"Add", ConvertSimpleOperator}, + {"AddN", ConvertSimpleOperator}, + {"All", ConvertSimpleOperator}, + {"ArgMax", ConvertArgMaxOperator}, + {"Assert", ConvertSimpleOperator}, + {"AvgPool", ConvertAvgPoolOperator}, + {"BatchMatMul", ConvertBatchMatMulOperator}, + {"BatchNormWithGlobalNormalization", + ConvertBatchNormWithGlobalNormalizationOperator}, + {"BatchToSpaceND", ConvertBatchToSpaceNDOperator}, + {"BiasAdd", ConvertBiasAddOperator}, + {"Cast", ConvertCastOperator}, + {"CheckNumerics", ConvertIdentityOperator}, + {"Concat", ConvertConcatOperator}, + {"ConcatV2", ConvertConcatOperator}, + {"Const", ConvertConstOperator}, + {"Conv2D", ConvertConvOperator}, + {"Conv2DBackpropInput", ConvertTransposeConvOperator}, + {"DepthToSpace", ConvertDepthToSpaceOperator}, + {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator}, + {"Div", ConvertSimpleOperator}, + {"DynamicPartition", ConvertDynamicPartitionOperator}, + {"DynamicStitch", ConvertDynamicStitchOperator}, + {"Equal", ConvertSimpleOperator}, + {"Exp", ConvertSimpleOperator}, + {"ExpandDims", ConvertSimpleOperator}, + {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs}, + {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars}, + {"Fill", ConvertSimpleOperator}, + {"Floor", ConvertFloorOperator}, + {"FloorDiv", ConvertSimpleOperator}, + {"FloorMod", ConvertSimpleOperator}, + {"FusedBatchNorm", ConvertFusedBatchNormOperator}, + {"Gather", ConvertGatherOperator}, + {"GatherV2", ConvertGatherOperator}, + {"Greater", ConvertSimpleOperator}, + {"GreaterEqual", + ConvertSimpleOperator}, + {"Identity", ConvertIdentityOperator}, + {"LRN", ConvertLRNOperator}, + {"LegacyFedInput", ConvertPlaceholderOperator}, + {"Less", ConvertSimpleOperator}, + {"LessEqual", ConvertSimpleOperator}, + {"Log", ConvertSimpleOperator}, + {"Log", ConvertSimpleOperator}, + {"LogSoftmax", ConvertSimpleOperator}, + {"MatMul", ConvertMatMulOperator}, + {"Max", ConvertMaxOperator}, + {"MaxPool", ConvertMaxPoolOperator}, + {"Maximum", ConvertSimpleOperator}, + {"Mean", ConvertMeanOperator}, + {"Merge", ConvertSimpleOperator}, + {"Min", ConvertMinOperator}, + {"Minimum", ConvertSimpleOperator}, + {"Mul", ConvertSimpleOperator}, + {"Neg", ConvertSimpleOperator}, + {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge}, + {"NoOp", ConvertNoOpOperator}, + {"NotEqual", ConvertSimpleOperator}, + {"Pack", ConvertStackOperator}, + {"Pad", ConvertSimpleOperator}, + {"PadV2", ConvertSimpleOperator}, + {"ParallelDynamicStitch", ConvertDynamicStitchOperator}, + {"Placeholder", ConvertPlaceholderOperator}, + {"PlaceholderWithDefault", ConvertIdentityOperator}, + {"RandomUniform", ConvertRandomUniform}, + {"Range", ConvertRangeOperator}, + {"Rank", ConvertSimpleOperator}, + {"RealDiv", ConvertSimpleOperator}, + {"Relu", ConvertSimpleOperator}, + {"Relu6", ConvertSimpleOperator}, + {"Reshape", ConvertSimpleOperator}, + {"ResizeBilinear", ConvertResizeBilinearOperator}, + {"Rsqrt", ConvertSimpleOperator}, + {"Select", ConvertSimpleOperator}, + {"Shape", ConvertShapeOperator}, + {"Sigmoid", ConvertSimpleOperator}, + {"Sin", ConvertSimpleOperator}, + {"Slice", ConvertSimpleOperator}, + {"Softmax", ConvertSoftmaxOperator}, + {"SpaceToBatchND", ConvertSpaceToBatchNDOperator}, + {"SpaceToDepth", ConvertSpaceToDepthOperator}, + {"SparseToDense", ConvertSparseToDenseOperator}, + {"Split", ConvertSplitOperator}, + {"Sqrt", ConvertSimpleOperator}, + {"Square", ConvertSimpleOperator}, + {"Squeeze", ConvertSqueezeOperator}, + {"Stack", ConvertStackOperator}, + {"StopGradient", ConvertIdentityOperator}, + {"StridedSlice", ConvertStridedSliceOperator}, + {"Sub", ConvertSimpleOperator}, + {"Sum", ConvertSumOperator}, + {"Svdf", ConvertSvdfOperator}, + {"Switch", ConvertSwitchOperator}, + {"Tanh", ConvertSimpleOperator}, + {"Tile", ConvertSimpleOperator}, + {"TopK", ConvertTopKV2Operator}, + {"TopKV2", ConvertTopKV2Operator}, + {"Transpose", ConvertSimpleOperator}, + }); +} + +tensorflow::Status ImportTensorFlowNode( + const tensorflow::NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, Model* model, + const ConverterMapType& converter_map) { + auto converter = converter_map.find(node.op()); + if (converter == converter_map.end()) { + return ConvertUnsupportedOperator(node, tf_import_flags, model); } else { - ConvertUnsupportedOperator(node, tf_import_flags, model); + return converter->second(node, tf_import_flags, model); } - return Status::OK(); } } // namespace internal @@ -1930,10 +1955,13 @@ std::unique_ptr ImportTensorFlowGraphDef( } Model* model = new Model; + const internal::ConverterMapType& converter_map = + internal::GetTensorFlowNodeConverterMap(); for (auto node : inlined_graph.node()) { StripZeroOutputIndexFromInputs(&node); - auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model); + auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model, + converter_map); CHECK(status.ok()) << status.error_message(); } diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index 835676662b9cb7ed20e578e2a35747a64ba443dc..90e6f698efee6a6a32da18a658e72c3e8b6550c0 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/status.h" namespace toco { -using port::Status; using tensorflow::AttrValue; using tensorflow::DT_BOOL; using tensorflow::DT_FLOAT; @@ -33,10 +33,17 @@ using tensorflow::DT_INT64; using tensorflow::DT_QUINT8; using tensorflow::DT_STRING; using tensorflow::NodeDef; +using tensorflow::Status; namespace internal { +using ConverterType = tensorflow::Status (*)( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model); +using ConverterMapType = std::unordered_map; + +ConverterMapType GetTensorFlowNodeConverterMap(); Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, - Model*); + Model*, const ConverterMapType&); } // namespace internal namespace { @@ -104,8 +111,9 @@ class ShapeImportTest : public ::testing::TestWithParam { Status ImportNode(const NodeDef& node) { Model model; - return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), - &model); + const auto converter = internal::GetTensorFlowNodeConverterMap(); + return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model, + converter); } }; @@ -117,9 +125,10 @@ TEST_P(ShapeImportTest, ShapeElementIsNegative) { NodeDef node; BuildConstNode({1, -2, 10}, GetParam(), 0, &node); auto status = ImportNode(node); - EXPECT_EQ(status.error_message(), - "Tensor shape should not include negative values (while processing " - "node 'Node1')"); + EXPECT_EQ( + status.error_message(), + "Tensor shape should not include negative values\n\t (while processing " + "node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeElementIsNegative, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -129,7 +138,7 @@ TEST_P(ShapeImportTest, ShapeElementTooLarge) { BuildConstNode({3000000000}, GetParam(), 0, &node); auto status = ImportNode(node); EXPECT_EQ(status.error_message(), - "Shape element overflows (while processing node 'Node1')"); + "Shape element overflows\n\t (while processing node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeElementTooLarge, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -139,7 +148,7 @@ TEST_P(ShapeImportTest, ShapeTooLarge) { BuildConstNode({1000000, 2000000, 2000000, 2000000}, GetParam(), 0, &node); auto status = ImportNode(node); EXPECT_EQ(status.error_message(), - "Tensor shape is too large (while processing node 'Node1')"); + "Tensor shape is too large\n\t (while processing node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeTooLarge, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -148,11 +157,11 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) { NodeDef node; BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node); auto status = ImportNode(node); - EXPECT_THAT( - status.error_message(), - ::testing::MatchesRegex( - "Neither input_content .0. nor .*_val .0. have the right " - "dimensions .8. for this .* tensor .while processing node 'Node1'.")); + EXPECT_THAT(status.error_message(), + ::testing::MatchesRegex( + "Neither input_content .0. nor .*_val .0. have the right " + "dimensions .8. for this .* tensor\n\t .while processing " + "node 'Node1'.")); } INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest, ::testing::ValuesIn(TestTypes())); diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 1a4f87e36361773a1ab1c3c673b5197aa2bdf6cc..ef170b38840dbc0193843251f8e59d6d075ba7e2 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -32,7 +32,7 @@ namespace toco { using tflite::QuantizationParams; -enum class OperatorType { +enum class OperatorType : uint8 { kNone, // General-purpose neural network operators. kAdd, @@ -96,38 +96,38 @@ enum class OperatorType { // Special operators used for importing TensorFlow nodes. // The general intent is to have some graph transformation either // drop them or rewrite them as general-purpose operators. - kTensorFlowAll, - kTensorFlowAssert, - kTensorFlowConcat, - kTensorFlowConcatV2, - kTensorFlowGreater, - kTensorFlowGreaterEqual, - kTensorFlowIdentity, - kTensorFlowLess, - kTensorFlowLessEqual, - kTensorFlowMax, - kTensorFlowMaximum, - kTensorFlowMin, - kTensorFlowMinimum, - kTensorFlowMatMul, - kTensorFlowMerge, + kAll, + kAssert, + kConcat, + kConcatV2, + kGreater, + kGreaterEqual, + kIdentity, + kLess, + kLessEqual, + kMax, // Reduction Max + kMaximum, // Element-wise Maximum + kMin, // Reduction Min + kMinimum, // Element-wise Minimum + kMatMul, + kMerge, kNeg, - kTensorFlowReshape, - kTensorFlowRsqrt, - kTensorFlowShape, - kTensorFlowSplit, - kTensorFlowSqrt, - kTensorFlowSquare, - kTensorFlowSum, - kTensorFlowSwitch, - kTensorFlowTile, + kReshape, + kRsqrt, + kShape, + kSplit, + kSqrt, + kSquare, + kSum, + kSwitch, + kTile, kTranspose, kTopK_V2, kDynamicPartition, kDynamicStitch, // An unsupported TF operation. It's only needed to be able to represent TF // graph internally and is expected to be dropped by graph transformations. - kTensorFlowUnsupported, + kUnsupported, // Finally, TensorFlow uses different conventions for axes ordering, // see AxesOrder, and this cannot always be resolved at the time of importing // nodes, as TensorFlow parameters may be constant-expression subgraphs @@ -136,6 +136,8 @@ enum class OperatorType { kReorderAxes, kSelect, kSparseToDense, + kEqual, + kNotEqual, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -153,6 +155,7 @@ enum class AxesOrder { k1HWO, // Our standard for DepthwiseConv weights kHWIM, // TensorFlow DepthwiseConv weights kNHWC, // TensorFlow activations + kHWOI, // TensorFlow back-prop conv weights }; // The type of the scalars in an array. @@ -171,7 +174,7 @@ enum class AxesOrder { // because we'll be dropping the array anyway (e.g. some exotic array types // may be involved only in debug-only subgraphs that we may not be interested // in actually supporting). -enum class ArrayDataType { +enum class ArrayDataType : uint8 { kNone, // 0 kBool, kFloat, @@ -798,7 +801,7 @@ struct DivOperator : Operator { // // TensorFlow equivalent: Identity struct TensorFlowIdentityOperator : Operator { - TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {} + TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {} }; // Batch matrix multiplication operator. This comes from the (deprecated) @@ -824,7 +827,7 @@ struct BatchMatMulOperator : Operator { // // TensorFlow equivalent: MatMul struct TensorFlowMatMulOperator : Operator { - TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {} + TensorFlowMatMulOperator() : Operator(OperatorType::kMatMul) {} }; // Padding operator. Pads a tensor with zeros. @@ -958,7 +961,7 @@ struct StridedSliceOperator : Operator { // TensorFlow equivalent: Reshape --- except that we only support a special case // here, where the output shape is a matrix (2D) shape. struct TensorFlowReshapeOperator : Operator { - TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {} + TensorFlowReshapeOperator() : Operator(OperatorType::kReshape) {} std::vector shape; }; @@ -1128,7 +1131,7 @@ struct SelectOperator : Operator { // // TensorFlow equivalent: Rsqrt struct TensorFlowRsqrtOperator : Operator { - TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {} + TensorFlowRsqrtOperator() : Operator(OperatorType::kRsqrt) {} }; // Stacks a list of rank-R tensors into one rank-(R+1) tensor. @@ -1154,10 +1157,10 @@ struct StackOperator : Operator { // This operation outputs a 1-D integer tensor representing the shape of // the input. // -// TensorFlow equivalent: Shape. We currently assume that the output is int32 -// and not int64. The output type could be stored herein. +// TensorFlow equivalent: Shape. struct TensorFlowShapeOperator : Operator { - TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {} + TensorFlowShapeOperator() : Operator(OperatorType::kShape) {} + ArrayDataType output_data_type = ArrayDataType::kInt32; }; // Element-wise square-root (x^0.5) operator. @@ -1167,7 +1170,7 @@ struct TensorFlowShapeOperator : Operator { // // TensorFlow equivalent: Sqrt struct TensorFlowSqrtOperator : Operator { - TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {} + TensorFlowSqrtOperator() : Operator(OperatorType::kSqrt) {} }; // Element-wise square (x*x) operator. @@ -1177,7 +1180,7 @@ struct TensorFlowSqrtOperator : Operator { // // TensorFlow equivalent: Square struct TensorFlowSquareOperator : Operator { - TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {} + TensorFlowSquareOperator() : Operator(OperatorType::kSquare) {} }; // Transposes a tensor. @@ -1205,24 +1208,24 @@ struct SubOperator : Operator { SubOperator() : Operator(OperatorType::kSub) {} }; -// Global sum reduction: computes the sum of all of entries in the input array. -// Thus the output is "0-dimensional": it consists of a single scalar value. +// Sum reduction: computes the sum of all of entries across the axes. // // Inputs: // inputs[0]: required: the input array // -// TensorFlow equivalent: Sum --- except that we only support the special case -// of global reduction across all dimensions. +// TensorFlow equivalent: Sum struct TensorFlowSumOperator : Operator { - TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {} + TensorFlowSumOperator() : Operator(OperatorType::kSum) {} bool keep_dims = false; }; // TensorFlow Tile equivalent. Refer to TensorFlow documentation for details. -// Not fully supported, just a placeholder to handle TensorFlow graphs and -// support graph transformations to other operator types by matching sub-graphs. +// +// Inputs: +// inputs[0]: required: the input array +// inputs[1]: required: int array with length of rank(input[0]) struct TensorFlowTileOperator : Operator { - TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {} + TensorFlowTileOperator() : Operator(OperatorType::kTile) {} }; // TensorFlow Slice equivalent. Refer to TensorFlow documentation for details. @@ -1237,7 +1240,7 @@ struct SliceOperator : Operator { // Not fully supported, just a placeholder to handle TensorFlow graphs and // support graph transformations to other operator types by matching sub-graphs. struct TensorFlowSplitOperator : Operator { - TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {} + TensorFlowSplitOperator() : Operator(OperatorType::kSplit) {} int num_split = 0; }; @@ -1248,7 +1251,7 @@ struct TensorFlowSplitOperator : Operator { // dimension then we can change this op into a DepthConcatenation op. // Otherwise, we hope for some other graph transformation to drop this node. struct TensorFlowConcatOperator : Operator { - TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {} + TensorFlowConcatOperator() : Operator(OperatorType::kConcat) {} }; // TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for @@ -1259,7 +1262,7 @@ struct TensorFlowConcatOperator : Operator { // dimension then we can change this op into a DepthConcatenation op. // Otherwise, we hope for some other graph transformation to drop this node. struct TensorFlowConcatV2Operator : Operator { - TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {} + TensorFlowConcatV2Operator() : Operator(OperatorType::kConcatV2) {} }; // TensorFlow Merge equivalent. Refer to TensorFlow documentation for details. @@ -1275,7 +1278,7 @@ struct TensorFlowConcatV2Operator : Operator { // control flow that can be resolved at tooling time (independently of input // activations). struct TensorFlowMergeOperator : Operator { - TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {} + TensorFlowMergeOperator() : Operator(OperatorType::kMerge) {} }; // TensorFlow Switch equivalent. Refer to TensorFlow documentation for details. @@ -1298,7 +1301,7 @@ struct TensorFlowMergeOperator : Operator { // control flow that can be resolved at tooling time (independently of input // activations). struct TensorFlowSwitchOperator : Operator { - TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {} + TensorFlowSwitchOperator() : Operator(OperatorType::kSwitch) {} }; // TensorFlow All equivalent. Refer to TensorFlow documentation for details. @@ -1307,7 +1310,7 @@ struct TensorFlowSwitchOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowAllOperator : Operator { - TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {} + TensorFlowAllOperator() : Operator(OperatorType::kAll) {} }; // TensorFlow Assert equivalent. Refer to TensorFlow documentation for details. @@ -1315,7 +1318,7 @@ struct TensorFlowAllOperator : Operator { // support graph transformations to other operator types by matching sub-graphs. // Typically, we just drop Assert nodes. struct TensorFlowAssertOperator : Operator { - TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {} + TensorFlowAssertOperator() : Operator(OperatorType::kAssert) {} }; // TensorFlow Less equivalent. Refer to TensorFlow documentation for details. @@ -1324,7 +1327,7 @@ struct TensorFlowAssertOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowLessOperator : Operator { - TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {} + TensorFlowLessOperator() : Operator(OperatorType::kLess) {} }; // TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for @@ -1334,8 +1337,7 @@ struct TensorFlowLessOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowLessEqualOperator : Operator { - TensorFlowLessEqualOperator() - : Operator(OperatorType::kTensorFlowLessEqual) {} + TensorFlowLessEqualOperator() : Operator(OperatorType::kLessEqual) {} }; // TensorFlow Less equivalent. Refer to TensorFlow documentation for details. @@ -1344,7 +1346,7 @@ struct TensorFlowLessEqualOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowGreaterOperator : Operator { - TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {} + TensorFlowGreaterOperator() : Operator(OperatorType::kGreater) {} }; // TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for @@ -1354,8 +1356,23 @@ struct TensorFlowGreaterOperator : Operator { // Typically, this is only used as an input to an Assert node, so can be // removed as an unused node as we drop Assert nodes. struct TensorFlowGreaterEqualOperator : Operator { - TensorFlowGreaterEqualOperator() - : Operator(OperatorType::kTensorFlowGreaterEqual) {} + TensorFlowGreaterEqualOperator() : Operator(OperatorType::kGreaterEqual) {} +}; + +// TensorFlow Equal equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowEqualOperator : Operator { + TensorFlowEqualOperator() : Operator(OperatorType::kEqual) {} +}; + +// TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for +// details. +struct TensorFlowNotEqualOperator : Operator { + TensorFlowNotEqualOperator() : Operator(OperatorType::kNotEqual) {} }; // Global max reduction: computes the max of all of entries in the input array. @@ -1367,7 +1384,7 @@ struct TensorFlowGreaterEqualOperator : Operator { // TensorFlow equivalent: Max --- except that we only support the special case // of global reduction across all dimensions. struct TensorFlowMaxOperator : Operator { - TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {} + TensorFlowMaxOperator() : Operator(OperatorType::kMax) {} bool keep_dims = false; }; @@ -1380,7 +1397,7 @@ struct TensorFlowMaxOperator : Operator { // TensorFlow equivalent: Min --- except that we only support the special case // of global reduction across all dimensions. struct TensorFlowMinOperator : Operator { - TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {} + TensorFlowMinOperator() : Operator(OperatorType::kMin) {} bool keep_dims = false; }; @@ -1393,7 +1410,7 @@ struct TensorFlowMinOperator : Operator { // // TensorFlow equivalent: Maximum struct TensorFlowMaximumOperator : Operator { - TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {} + TensorFlowMaximumOperator() : Operator(OperatorType::kMaximum) {} }; // Element-wise minimum operator. Currently it only supports scalar as @@ -1405,14 +1422,13 @@ struct TensorFlowMaximumOperator : Operator { // // TensorFlow equivalent: Minimum struct TensorFlowMinimumOperator : Operator { - TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {} + TensorFlowMinimumOperator() : Operator(OperatorType::kMinimum) {} }; // General TF operation, unsupported by tf.mini. Expected to be dropped by // graph transformations. struct TensorFlowUnsupportedOperator : Operator { - TensorFlowUnsupportedOperator() - : Operator(OperatorType::kTensorFlowUnsupported) {} + TensorFlowUnsupportedOperator() : Operator(OperatorType::kUnsupported) {} // The original TF operation type. Used for diagnostic purposes. string tensorflow_op; @@ -1625,8 +1641,8 @@ struct SparseToDenseOperator : Operator { // be used for the transient array at hand. The 'start' and 'end' values are // offsets from the start of the workspace buffer, expressed in bytes. struct Alloc { - int start = 0; - int end = 0; + int64 start = 0; + int64 end = 0; }; inline bool operator<(const Alloc& a, const Alloc& b) { diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 0f104d5e2d02dc852a2720c78995108a00924298..4c9f1aa4b0274b5123bb3baa9b9fca1463bda4c3 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -48,7 +48,7 @@ bool ParseModelFlagsFromCommandLineFlags( "that information from the input file."), Flag("input_arrays", parsed_flags.input_arrays.bind(), parsed_flags.input_arrays.default_value(), - "Names of the output arrays, comma-separated. If not specified, " + "Names of the input arrays, comma-separated. If not specified, " "will try to read that information from the input file."), Flag("output_array", parsed_flags.output_array.bind(), parsed_flags.output_array.default_value(), diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD index a954f1d6ba65f21cb99df226790f4bf4951581b1..93fe756a55d378fa205ff88be5e18aff586e5dca 100644 --- a/tensorflow/contrib/lite/toco/python/BUILD +++ b/tensorflow/contrib/lite/toco/python/BUILD @@ -12,6 +12,7 @@ cc_library( deps = [ "//tensorflow/contrib/lite/toco:model_flags_proto_cc", "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_graphviz_dump_options", "//tensorflow/contrib/lite/toco:toco_port", "//tensorflow/contrib/lite/toco:toco_tooling", "//tensorflow/core:lib", diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc index 5b1db852b4f8e89c1a591cfe18a0ab0aa2db04c9..d93e104038741e6e59608f04115854d611f1f9ae 100644 --- a/tensorflow/contrib/lite/toco/python/toco_python_api.cc +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/python/toco_python_api.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" @@ -62,7 +63,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error); if (error) return nullptr; - // Use toco to produce new outputs + // Use TOCO to produce new outputs. toco::ModelFlags model_flags; if (!model_flags.ParseFromString(model_flags_proto_txt)) { LOG(FATAL) << "Model proto failed to parse." << std::endl; @@ -71,6 +72,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, if (!toco_flags.ParseFromString(toco_flags_proto_txt)) { LOG(FATAL) << "Toco proto failed to parse." << std::endl; } + + auto& dump_options = *GraphVizDumpOptions::singleton(); + if (toco_flags.has_dump_graphviz_dir()) { + dump_options.dump_graphviz = toco_flags.dump_graphviz_dir(); + } + if (toco_flags.has_dump_graphviz_include_video()) { + dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video(); + } + + // Convert model. std::unique_ptr model = toco::Import(toco_flags, model_flags, input_contents_txt); toco::Transform(toco_flags, model.get()); diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index e1025c66642d2860c5916bf7625f1c0403c9901c..a02f90988b2863900b6a735fd69aa1975a762338 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -24,6 +24,7 @@ cc_library( deps = [ ":types", "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 5daa703c80b3b5d9152c5d21976260f21679a3f2..19722468079a32b76f6952db6ca818da470a03ac 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -49,7 +49,7 @@ details::OperatorKey GetOperatorKey( const ::toco::Operator& op, const std::map>& ops_by_type) { string custom_code; - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { const TensorFlowUnsupportedOperator& unsupported_op = static_cast(op); custom_code = unsupported_op.tensorflow_op; @@ -99,7 +99,8 @@ void LoadOperatorsMap( Offset>> ExportTensors( const Model& model, const details::TensorsMap& tensors_map, - FlatBufferBuilder* builder, std::vector* buffers_to_write) { + FlatBufferBuilder* builder, std::vector* buffers_to_write, + const std::set& variable_tensor_indices) { // In the end we will need to produce a vector sorted by the indices of the // tensors in the tensors_map. std::map> ordered_tensors; @@ -139,9 +140,11 @@ Offset>> ExportTensors( scale, zero_point); int index = tensors_map.at(tensor_name); + bool is_variable = + variable_tensor_indices.find(index) != variable_tensor_indices.end(); ordered_tensors[index] = CreateTensor(*builder, builder->CreateVector(shape), type, buffer_index, - builder->CreateString(tensor_name), q_param); + builder->CreateString(tensor_name), q_param, is_variable); } std::vector> tensor_vector; @@ -208,7 +211,7 @@ Offset>> ExportOperatorCodes( ordered_opcodes[op_index] = CreateOperatorCode(*builder, builtin_ops[name], 0, op_version); } else { - // This could be a kTensorFlowUnsupported, in which case we should be + // This could be a kUnsupported, in which case we should be // able to retrieve the original Tensorflow name from the OperatorKey, or // this could be a proper TOCO operator that is completely unknown to TF // Lite. @@ -239,7 +242,10 @@ Offset>> ExportOperators( const Model& model, const std::map>& ops_by_type, const details::OperatorsMap& operators_map, - const details::TensorsMap& tensors_map, FlatBufferBuilder* builder) { + const details::TensorsMap& tensors_map, FlatBufferBuilder* builder, + std::set* variable_tensor_indices) { + variable_tensor_indices->clear(); + // The operators are in execution order, so we just follow tf.mini order. std::vector> op_vector; for (const auto& op : model.operators) { @@ -256,18 +262,36 @@ Offset>> ExportOperators( int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type)); - // This is a custom op unless we can find it in ops_by_type, and even then - // it could be a custom op (such as kTensorFlowUnsupported). + auto tflite_op_it = ops_by_type.find(op->type); + BaseOperator* tflite_op = tflite_op_it == ops_by_type.end() + ? nullptr + : tflite_op_it->second.get(); + // This is a custom op unless we can find it in ops_by_type, and even then + // it could be a custom op (such as kUnsupported). auto options = Options::Custom(0); - if (ops_by_type.count(op->type) != 0) { - options = ops_by_type.at(op->type)->Serialize(*op, builder); + + std::vector mutating_input_variables; + if (tflite_op) { + options = tflite_op->Serialize(*op, builder); + mutating_input_variables = tflite_op->GetMutatingInputVariables(*op); + + if (!mutating_input_variables.empty()) { + for (int i = 0; i < op->inputs.size(); ++i) { + if (!mutating_input_variables[i]) { + continue; + } + int32_t variable_tensor_index = tensors_map.at(op->inputs[i]); + variable_tensor_indices->insert(variable_tensor_index); + } + } } // The only supported CustomOptionFormat is FLEXBUFFERS now. op_vector.push_back(CreateOperator( *builder, op_index, builder->CreateVector(inputs), builder->CreateVector(outputs), options.type, options.builtin, - options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS)); + options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS, + builder->CreateVector(mutating_input_variables))); } return builder->CreateVector(op_vector); @@ -308,14 +332,12 @@ void Export( Array empty_array; buffers_to_write.push_back(&empty_array); - auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write); - auto inputs = ExportInputTensors(model, tensors_map, &builder); - auto outputs = ExportOutputTensors(model, tensors_map, &builder); - std::set error_summary; auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, &builder, &error_summary); + const string fake_quant_operation_name = "FAKE_QUANT"; + if (error_summary.count(fake_quant_operation_name) != 0) { LOG(ERROR) << fake_quant_operation_name @@ -327,6 +349,21 @@ void Export( error_summary.erase(fake_quant_operation_name); } if (!allow_custom_ops && !error_summary.empty()) { + // Remove ExpandDims and ReorderAxes from unimplemented list unless they + // compose the list. Both ops are removed during graph transformations. + // However, if an op is unimplemented earlier in the model, the graph + // transformation is unable to run because the output shape is not defined. + // This causes unnecessary confusion during model conversion time. + std::set error_summary_final; + for (const auto& op_type : error_summary) { + if (op_type != "ReorderAxes" && op_type != "ExpandDims") { + error_summary_final.insert(op_type); + } + } + if (error_summary_final.empty()) { + error_summary_final = error_summary; + } + LOG(QFATAL) << "Some of the operators in the model are not supported by " "the standard TensorFlow Lite runtime. If you have a custom " @@ -334,14 +371,21 @@ void Export( "--allow_custom_ops, or by setting allow_custom_ops=True " "when calling tf.contrib.lite.toco_convert(). Here is a list " "of operators for which you will need custom implementations: " - << absl::StrJoin(error_summary, ", ") << "."; + << absl::StrJoin(error_summary_final, ", ") << "."; } - auto ops = - ExportOperators(model, ops_by_type, operators_map, tensors_map, &builder); + std::set variable_tensor_indices; + auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map, + &builder, &variable_tensor_indices); + + auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write, + variable_tensor_indices); + auto inputs = ExportInputTensors(model, tensors_map, &builder); + auto outputs = ExportOutputTensors(model, tensors_map, &builder); // TODO(aselle): add support to toco for multiple subgraphs. - auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops); + auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops, + /* name */ 0); std::vector> subgraphs = {subgraph}; auto buffers = ExportBuffers(model, buffers_to_write, &builder); diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 098d2163e6c2fe26f3cb9cdf9959df62a1a4baf0..58ea5c725c378827aac79f2a5a2cdca59ccc0162 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -45,7 +45,7 @@ namespace details { using TensorsMap = std::unordered_map; // A key to identify an operator. -// Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to +// Only when `type` is `kUnsupported`, `custom_code` is filled to // identify which operation is used. struct OperatorKey { OperatorKey(OperatorType type, const std::string& custom_code, int version) diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index 409e7d72a57076ec2832c5d12b52829477624f74..d1fdbcb8e9131e1d65fa32ca0395bbc17b2014e7 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -73,8 +73,8 @@ TEST_F(ExportTest, LoadOperatorsMap) { EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]); EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]); EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]); - EXPECT_EQ(3, operators[details::OperatorKey( - OperatorType::kTensorFlowUnsupported, "MyCrazyOp", 1)]); + EXPECT_EQ(3, operators[details::OperatorKey(OperatorType::kUnsupported, + "MyCrazyOp", 1)]); } TEST_F(ExportTest, Export) { diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index c0e7ab2ef57ed8edf1b7cda08c64f6ae66172af3..d1867bd4fa46a8a9dcd4c6abd4ef20b82c3854b4 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -113,15 +113,35 @@ void ImportOperators( << operators_table.size(); } string opname = operators_table.at(index); + + // Find and use the appropriate operator deserialization factory. + std::unique_ptr new_op = nullptr; if (ops_by_name.count(opname) == 0) { - LOG(FATAL) << "Op '" << opname << "' not supported"; + string effective_opname = "TENSORFLOW_UNSUPPORTED"; + if (ops_by_name.count(effective_opname) == 0) { + LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found."; + } + new_op = ops_by_name.at(effective_opname) + ->Deserialize(input_op->builtin_options(), + input_op->custom_options()); + if (new_op->type == OperatorType::kUnsupported) { + auto* unsupported_op = + static_cast(new_op.get()); + unsupported_op->tensorflow_op = opname; + // TODO(b/109932940): Remove this when quantized is removed. + // For now, we assume all ops are quantized. + unsupported_op->quantized = true; + } else { + LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator"; + } + } else { + new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(), + input_op->custom_options()); } - - auto new_op = ops_by_name.at(opname)->Deserialize( - input_op->builtin_options(), input_op->custom_options()); model->operators.emplace_back(new_op.release()); auto* op = model->operators.back().get(); + // Make sure all the inputs and outputs are hooked up. auto inputs = input_op->inputs(); for (int i = 0; i < inputs->Length(); i++) { auto input_index = inputs->Get(i); diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index a8518adefcf39221020a4cd531d0a4fe33f9b5ae..290a925c1ef68315473fcd06006114836cd08a4f 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/toco/tflite/operator.h" +// TODO(ycling): Consider refactoring to extract the LSTM definition out of +// graph_transformation module. +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" #include "tensorflow/contrib/lite/toco/tflite/custom_operator.h" #include "tensorflow/contrib/lite/toco/tflite/simple_operator.h" @@ -668,16 +671,55 @@ class Lstm : public BuiltinOperator GetMutatingInputVariables( + const Operator& op) const override { + const auto& lstm_op = static_cast(op); + + std::vector mutating_input_variables(op.inputs.size(), false); + switch (lstm_op.kernel_type) { + case LstmCellOperator::KERNEL_FULL: { + mutating_input_variables[kInputActivationStateTensor] = true; + mutating_input_variables[kInputCellStateTensor] = true; + break; + } + case LstmCellOperator::KERNEL_BASIC: { + mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true; + mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true; + break; + } + } + return mutating_input_variables; + } +}; + +class Mean : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } }; -class Mean : public BuiltinOperator { +class Sum + : public BuiltinOperator { public: using BuiltinOperator::BuiltinOperator; flatbuffers::Offset WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { - return ::tflite::CreateMeanOptions(*builder, op.keep_dims); + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); } void ReadOptions(const TfLiteOptions& options, @@ -876,6 +918,26 @@ class ExpandDims int GetVersion(const Operator& op) const override { return 1; } }; +class Shape + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateShapeOptions( + *builder, DataType::Serialize(op.output_data_type)); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->output_data_type = DataType::Deserialize(options.out_type()); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -936,6 +998,20 @@ class TensorFlowUnsupported : public BaseOperator { fbb->Bool(key, attr.b()); has_valid_attr = true; break; + case tensorflow::AttrValue::kList: + if (attr.list().i_size() > 0) { + auto start = fbb->StartVector(key); + for (const int64_t v : attr.list().i()) { + fbb->Add(v); + } + fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); + has_valid_attr = true; + } else { + LOG(WARNING) + << "Ignoring unsupported type in list attribute with key '" + << key << "'"; + } + break; default: LOG(WARNING) << "Ignoring unsupported attribute type with key '" << key << "'"; @@ -972,6 +1048,14 @@ class TensorFlowUnsupported : public BaseOperator { case flexbuffers::TYPE_BOOL: (*attr)[key].set_b(value.AsBool()); break; + case flexbuffers::TYPE_VECTOR_INT: { + auto* list = (*attr)[key].mutable_list(); + const auto& vector = value.AsTypedVector(); + for (size_t i = 0; i < vector.size(); i++) { + list->add_i(vector[i].AsInt64()); + } + break; + } default: LOG(WARNING) << "Ignoring unsupported attribute type with key '" << key << "'"; @@ -1030,8 +1114,8 @@ std::vector> BuildOperatorList() { ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad)); ops.emplace_back( new PadV2(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2)); - ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE, - OperatorType::kTensorFlowReshape)); + ops.emplace_back( + new Reshape(::tflite::BuiltinOperator_RESHAPE, OperatorType::kReshape)); ops.emplace_back( new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax)); ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH, @@ -1042,12 +1126,13 @@ std::vector> BuildOperatorList() { OperatorType::kTranspose)); ops.emplace_back( new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); + ops.emplace_back(new Sum(::tflite::BuiltinOperator_SUM, OperatorType::kSum)); ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, OperatorType::kResizeBilinear)); ops.emplace_back( new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); - ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT, - OperatorType::kTensorFlowSplit)); + ops.emplace_back( + new Split(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit)); ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice)); ops.emplace_back( @@ -1059,27 +1144,27 @@ std::vector> BuildOperatorList() { ops.emplace_back( new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); ops.emplace_back( - new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTensorFlowTile)); + new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile)); ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS, OperatorType::kExpandDims)); ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv)); ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense)); + ops.emplace_back( + new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape)); // Custom Operators. ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); - ops.emplace_back(new TensorFlowUnsupported( - "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported)); + ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED", + OperatorType::kUnsupported)); // There operators are supported by Toco, but not by TF Lite, and has no // attributes. ops.emplace_back( new SimpleOperator("ADDN", OperatorType::kAddN)); - ops.emplace_back(new SimpleOperator( - "RSQRT", OperatorType::kTensorFlowRsqrt)); // Simple Operators. ops.emplace_back(new SimpleOperator( "DEQUANTIZE", OperatorType::kDequantize)); @@ -1101,23 +1186,33 @@ std::vector> BuildOperatorList() { ops.emplace_back(new SimpleOperator( "LOG_SOFTMAX", OperatorType::kLogSoftmax)); ops.emplace_back(new SimpleOperator( - "MAXIMUM", OperatorType::kTensorFlowMaximum)); + "MAXIMUM", OperatorType::kMaximum)); // Element-wise Maximum ops.emplace_back(new SimpleOperator( - "MINIMUM", OperatorType::kTensorFlowMinimum)); + "MINIMUM", OperatorType::kMinimum)); // Element-wise Minimum ops.emplace_back(new SimpleOperator( - "GREATER", OperatorType::kTensorFlowGreater)); + "GREATER", OperatorType::kGreater)); ops.emplace_back(new SimpleOperator( - "GREATER_EQUAL", OperatorType::kTensorFlowGreaterEqual)); - ops.emplace_back(new SimpleOperator( - "LESS", OperatorType::kTensorFlowLess)); + "GREATER_EQUAL", OperatorType::kGreaterEqual)); + ops.emplace_back( + new SimpleOperator("LESS", OperatorType::kLess)); ops.emplace_back(new SimpleOperator( - "LESS_EQUAL", OperatorType::kTensorFlowLessEqual)); + "LESS_EQUAL", OperatorType::kLessEqual)); + ops.emplace_back(new SimpleOperator( + "EQUAL", OperatorType::kEqual)); + ops.emplace_back(new SimpleOperator( + "NOT_EQUAL", OperatorType::kNotEqual)); ops.emplace_back(new SimpleOperator("NEG", OperatorType::kNeg)); ops.emplace_back( new SimpleOperator("SELECT", OperatorType::kSelect)); ops.emplace_back( new SimpleOperator("SLICE", OperatorType::kSlice)); + // Element-wise operator ops.emplace_back(new SimpleOperator("SIN", OperatorType::kSin)); + ops.emplace_back(new SimpleOperator("LOG", OperatorType::kLog)); + ops.emplace_back( + new SimpleOperator("SQRT", OperatorType::kSqrt)); + ops.emplace_back(new SimpleOperator( + "RSQRT", OperatorType::kRsqrt)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index 5e9c20e40dd6274e0839379883b6dbe53064a0fc..d9ea23edf2b08146773ca58762623397e0f6257c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -87,6 +87,17 @@ class BaseOperator { // overridden. (See example in `operator_test.cc`) virtual int GetVersion(const Operator& op) const = 0; + // Given a Toco `Operator`, return a list of booleans indicating the op + // mutates which input variables. + // * If the op mutates any input variables, it should return a list of bool + // with the same length as inputs. + // * Otherwise, it will return an empty list. + virtual std::vector GetMutatingInputVariables( + const Operator& op) const { + // Most ops don't have variable tensors. This function can be overridden. + return std::vector(); + } + private: string name_; OperatorType type_; diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index d63c99a5f992beca8ba6b5cd034a6f370304eae1..79c8e5d738ab7da12a279c86df8b03d39a924fa1 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -74,8 +74,10 @@ class OperatorTest : public ::testing::Test { auto new_toco_op = op.Deserialize(output_options->builtin_options(), output_options->custom_options()); - CHECK(dynamic_cast(new_toco_op.get())) - << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to " + CHECK(new_toco_op->type == toco_op.type) + << "The type of the serialized and deserialized" + << HelpfulOperatorTypeName(*new_toco_op) + << " does not match the type of the original " << HelpfulOperatorTypeName(toco_op); return std::unique_ptr(dynamic_cast(new_toco_op.release())); @@ -110,15 +112,20 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax); CheckSimpleOperator( - "MAXIMUM", OperatorType::kTensorFlowMaximum); + "MAXIMUM", OperatorType::kMaximum); // Element-wise Maximum CheckSimpleOperator( - "MINIMUM", OperatorType::kTensorFlowMinimum); - CheckSimpleOperator("LESS", - OperatorType::kTensorFlowLess); + "MINIMUM", OperatorType::kMinimum); // Element-wise Minimum + CheckSimpleOperator("LESS", OperatorType::kLess); CheckSimpleOperator("NEG", OperatorType::kNeg); CheckSimpleOperator("SELECT", OperatorType::kSelect); CheckSimpleOperator("SLICE", OperatorType::kSlice); CheckSimpleOperator("SIN", OperatorType::kSin); + CheckSimpleOperator("EQUAL", OperatorType::kEqual); + CheckSimpleOperator("NOT_EQUAL", + OperatorType::kNotEqual); + CheckSimpleOperator("LOG", OperatorType::kLog); + CheckSimpleOperator("SQRT", OperatorType::kSqrt); + CheckSimpleOperator("RSQRT", OperatorType::kRsqrt); } TEST_F(OperatorTest, BuiltinAdd) { @@ -247,7 +254,7 @@ TEST_F(OperatorTest, BuiltinReshape) { TensorFlowReshapeOperator op; op.shape = {1, 2, 4, 5, 8}; auto output_toco_op = SerializeAndDeserialize( - GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op); + GetOperator("RESHAPE", OperatorType::kReshape), op); EXPECT_EQ(op.shape, output_toco_op->shape); } @@ -270,8 +277,8 @@ TEST_F(OperatorTest, BuiltinSpaceToDepth) { TEST_F(OperatorTest, CustomSplit) { TensorFlowSplitOperator op; op.num_split = 123; - auto output_toco_op = SerializeAndDeserialize( - GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op); + auto output_toco_op = + SerializeAndDeserialize(GetOperator("SPLIT", OperatorType::kSplit), op); EXPECT_EQ(op.num_split, output_toco_op->num_split); } @@ -420,6 +427,14 @@ TEST_F(OperatorTest, BuiltinTransposeConv) { EXPECT_EQ(op.padding.type, output_toco_op->padding.type); } +TEST_F(OperatorTest, BuiltinShape) { + TensorFlowShapeOperator op; + op.output_data_type = ArrayDataType::kInt64; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("SHAPE", OperatorType::kShape), op); + EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); +} + TEST_F(OperatorTest, BuiltinSparseToDense) { SparseToDenseOperator op; op.validate_indices = false; @@ -439,12 +454,17 @@ TEST_F(OperatorTest, TensorFlowUnsupported) { (*attr)["str_attr"].set_s("Hello World"); (*attr)["int_attr"].set_i(17); (*attr)["bool_attr"].set_b(true); + { + auto* list = (*attr)["list_int_attr"].mutable_list(); + list->add_i(1); + list->add_i(20); + list->add_i(1LL << 40); + list->add_i(-(1LL << 40)); + } node_def.SerializeToString(&op.tensorflow_node_def); - auto output_toco_op = - SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", - OperatorType::kTensorFlowUnsupported), - op); + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op); ::tensorflow::NodeDef output_node_def; output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); @@ -453,15 +473,22 @@ TEST_F(OperatorTest, TensorFlowUnsupported) { EXPECT_EQ("Hello World", output_attr.at("str_attr").s()); EXPECT_EQ(17, output_attr.at("int_attr").i()); EXPECT_EQ(true, output_attr.at("bool_attr").b()); + + { + const auto& list = output_attr.at("list_int_attr").list(); + ASSERT_EQ(4, list.i_size()); + EXPECT_EQ(1, list.i(0)); + EXPECT_EQ(20, list.i(1)); + EXPECT_EQ(1LL << 40, list.i(2)); + EXPECT_EQ(-(1LL << 40), list.i(3)); + } } TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; - auto output_toco_op = - SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED", - OperatorType::kTensorFlowUnsupported), - op); + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op); ::tensorflow::NodeDef output_node_def; output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc index 4867c3a62e68406428644cd05bddf212008c2656..42c5d7e8ebc3a7b90963a92843af616d9e6532d6 100644 --- a/tensorflow/contrib/lite/toco/tflite/types.cc +++ b/tensorflow/contrib/lite/toco/tflite/types.cc @@ -88,6 +88,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { switch (array_data_type) { case ArrayDataType::kFloat: return ::tflite::TensorType_FLOAT32; + case ArrayDataType::kInt16: + return ::tflite::TensorType_INT16; case ArrayDataType::kInt32: return ::tflite::TensorType_INT32; case ArrayDataType::kInt64: @@ -109,6 +111,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) { switch (::tflite::TensorType(tensor_type)) { case ::tflite::TensorType_FLOAT32: return ArrayDataType::kFloat; + case ::tflite::TensorType_INT16: + return ArrayDataType::kInt16; case ::tflite::TensorType_INT32: return ArrayDataType::kInt32; case ::tflite::TensorType_INT64: @@ -131,6 +135,8 @@ flatbuffers::Offset> DataBuffer::Serialize( switch (array.data_type) { case ArrayDataType::kFloat: return CopyBuffer(array, builder); + case ArrayDataType::kInt16: + return CopyBuffer(array, builder); case ArrayDataType::kInt32: return CopyBuffer(array, builder); case ArrayDataType::kInt64: @@ -154,6 +160,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, switch (tensor.type()) { case ::tflite::TensorType_FLOAT32: return CopyBuffer(buffer, array); + case ::tflite::TensorType_INT16: + return CopyBuffer(buffer, array); case ::tflite::TensorType_INT32: return CopyBuffer(buffer, array); case ::tflite::TensorType_INT64: diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc index 564f303b9bb41a777633ecabd666aa93ec3faefe..8c6ef95bfab0a5e9b410748eabf9570eec52c2e0 100644 --- a/tensorflow/contrib/lite/toco/tflite/types_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc @@ -151,6 +151,12 @@ TEST(DataBuffer, Int32) { ::testing::ElementsAre(1, 1 << 30)); } +TEST(DataBuffer, Int16) { + Array recovered = ToFlatBufferAndBack({1, 1 << 14}); + EXPECT_THAT(recovered.GetBuffer().data, + ::testing::ElementsAre(1, 1 << 14)); +} + TEST(DataBuffer, String) { Array recovered = ToFlatBufferAndBack( {"AA", "BBB", "Best. String. Ever."}); diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index 4fe57879fb0f38a21aac01283bc68077aa4be771..ad4e94ded9f9730842a257e065d9aec2b1cbfac8 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -174,4 +174,13 @@ message TocoFlags { // Computation is still done in float, but reduces model size (at the cost of // accuracy and latency). optional bool quantize_weights = 20 [default = false]; + + // Full filepath of folder to dump the graphs at various stages of processing + // GraphViz .dot files. Preferred over --output_format=GRAPHVIZ_DOT in order + // to keep the requirements of the output file. + optional string dump_graphviz_dir = 24; + + // Boolean indicating whether to dump the graph after every graph + // transformation. + optional bool dump_graphviz_include_video = 25; } diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc index 3a5911c28dc5462b5d3747f6af6aa82026a23466..de76fd4032d24eff8a6c2fd0c16a911b9c00186b 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #if defined(__ANDROID__) && defined(__ARM_ARCH_7A__) @@ -61,8 +63,12 @@ void CheckInitGoogleIsDone(const char* message) { namespace file { // Conversion to our wrapper Status. -Status ToStatus(const ::util::Status& uts) { - return Status(uts.ok(), uts.error_message()); +tensorflow::Status ToStatus(const ::util::Status& uts) { + if (!uts.ok()) { + return tensorflow::Status(tensorflow::errors::Code(uts.error_code()), + uts.error_message()); + } + return tensorflow::Status::OK(); } // Conversion to our wrapper Options. @@ -71,7 +77,7 @@ toco::port::file::Options ToOptions(const ::file::Options& options) { return Options(); } -Status Writable(const string& filename) { +tensorflow::Status Writable(const string& filename) { File* f = nullptr; const auto status = ::file::Open(filename, "w", &f, ::file::Defaults()); if (f) { @@ -80,22 +86,24 @@ Status Writable(const string& filename) { return ToStatus(status); } -Status Readable(const string& filename, const file::Options& options) { +tensorflow::Status Readable(const string& filename, + const file::Options& options) { return ToStatus(::file::Readable(filename, ::file::Defaults())); } -Status Exists(const string& filename, const file::Options& options) { +tensorflow::Status Exists(const string& filename, + const file::Options& options) { auto status = ::file::Exists(filename, ::file::Defaults()); return ToStatus(status); } -Status GetContents(const string& filename, string* contents, - const file::Options& options) { +tensorflow::Status GetContents(const string& filename, string* contents, + const file::Options& options) { return ToStatus(::file::GetContents(filename, contents, ::file::Defaults())); } -Status SetContents(const string& filename, const string& contents, - const file::Options& options) { +tensorflow::Status SetContents(const string& filename, const string& contents, + const file::Options& options) { return ToStatus(::file::SetContents(filename, contents, ::file::Defaults())); } @@ -139,37 +147,42 @@ void CheckInitGoogleIsDone(const char* message) { namespace file { -Status Writable(const string& filename) { +tensorflow::Status Writable(const string& filename) { FILE* f = fopen(filename.c_str(), "w"); if (f) { fclose(f); - return Status(true, ""); + return tensorflow::Status::OK(); } - return Status(false, "not writable"); + return tensorflow::errors::NotFound("not writable"); } -Status Readable(const string& filename, const file::Options& options) { +tensorflow::Status Readable(const string& filename, + const file::Options& options) { FILE* f = fopen(filename.c_str(), "r"); if (f) { fclose(f); - return Status(true, ""); + return tensorflow::Status::OK(); } - return Status(false, "not readable"); + return tensorflow::errors::NotFound("not readable"); } -Status Exists(const string& filename, const file::Options& options) { +tensorflow::Status Exists(const string& filename, + const file::Options& options) { struct stat statbuf; int ret = stat(filename.c_str(), &statbuf); - return Status(ret != -1, ""); + if (ret == -1) { + return tensorflow::errors::NotFound("file doesn't exist"); + } + return tensorflow::Status::OK(); } -Status GetContents(const string& path, string* output, - const file::Options& options) { +tensorflow::Status GetContents(const string& path, string* output, + const file::Options& options) { output->clear(); int fd = open(path.c_str(), O_RDONLY); if (fd == -1) { - return Status(false, "can't open() for read"); + return tensorflow::errors::NotFound("can't open() for read"); } // Direct read, for speed. @@ -180,25 +193,25 @@ Status GetContents(const string& path, string* output, if (size == 0) { // Done. close(fd); - return Status(true, ""); + return tensorflow::Status::OK(); } else if (size == -1) { // Error. close(fd); - return Status(false, "error during read()"); + return tensorflow::errors::Internal("error during read()"); } else { output->append(buffer, size); } } CHECK(0); - return Status(false, "internal error"); + return tensorflow::errors::Internal("internal error"); } -Status SetContents(const string& filename, const string& contents, - const file::Options& options) { +tensorflow::Status SetContents(const string& filename, const string& contents, + const file::Options& options) { int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664); if (fd == -1) { - return Status(false, "can't open() for write"); + return tensorflow::errors::Internal("can't open() for write"); } size_t i = 0; @@ -207,13 +220,13 @@ Status SetContents(const string& filename, const string& contents, ssize_t written = write(fd, &contents[i], to_write); if (written == -1) { close(fd); - return Status(false, "write() error"); + return tensorflow::errors::Internal("write() error"); } i += written; } close(fd); - return Status(true, ""); + return tensorflow::Status::OK(); } string JoinPath(const string& base, const string& filename) { diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h index b00b1e89e856190787d2d40096c9a5321bd80604..17f82b9dd7dcc633aa204038b6d965f4eb6967bb 100644 --- a/tensorflow/contrib/lite/toco/toco_port.h +++ b/tensorflow/contrib/lite/toco/toco_port.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "google/protobuf/text_format.h" #include "tensorflow/contrib/lite/toco/format_port.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/platform.h" #if defined(PLATFORM_GOOGLE) @@ -54,26 +55,6 @@ double round(double x); namespace toco { namespace port { -class Status { - public: - static Status OK() { return Status(true, ""); } - - // Create a failed status with no message. - Status() {} - - Status(bool ok, const string& message) : ok_(ok), message_(message) {} - - void AppendMessage(const string& message) { message_ += message; } - - bool ok() const { return ok_; } - - const string error_message() const { return message_; } - - private: - bool ok_ = false; - string message_; -}; - void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags); void CheckInitGoogleIsDone(const char* message); @@ -83,14 +64,14 @@ inline Options Defaults() { Options o; return o; } -Status GetContents(const string& filename, string* contents, - const Options& options); -Status SetContents(const string& filename, const string& contents, - const Options& options); +tensorflow::Status GetContents(const string& filename, string* contents, + const Options& options); +tensorflow::Status SetContents(const string& filename, const string& contents, + const Options& options); string JoinPath(const string& base, const string& filename); -Status Writable(const string& filename); -Status Readable(const string& filename, const Options& options); -Status Exists(const string& filename, const Options& options); +tensorflow::Status Writable(const string& filename); +tensorflow::Status Readable(const string& filename, const Options& options); +tensorflow::Status Exists(const string& filename, const Options& options); } // namespace file // Copy `src` string to `dest`. User must ensure `dest` has enough space. diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 1fe76f8163cdf23b27f8baaf2d9c6d99b1aa3747..2534d1ef2ad3409a42836ad9470d8ac53d62894a 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -34,11 +34,11 @@ limitations under the License. namespace toco { namespace { -// CHECK-fails if the model contains a kTensorFlowUnsupported operation. +// CHECK-fails if the model contains a kUnsupported operation. void CheckUnsupportedOperations(const Model& model) { std::set unsupported_ops; for (auto& op : model.operators) { - if (op->type == OperatorType::kTensorFlowUnsupported) { + if (op->type == OperatorType::kUnsupported) { unsupported_ops.insert( static_cast(op.get()) ->tensorflow_op); @@ -56,6 +56,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ConvertSqueezeToReshape); transformations->Add(new ConvertTrivialAddNToAdd); transformations->Add(new ConvertTrivialStackToReshape); + transformations->Add(new ConvertTrivialTileToConcat); transformations->Add(new ConvertTrivialTransposeToReshape); transformations->Add(new ConvertReorderAxes); transformations->Add(new ResolveReshapeAttributes); @@ -76,6 +77,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowMatMul); transformations->Add(new FuseBinaryIntoPrecedingAffine); transformations->Add(new FuseBinaryIntoFollowingAffine); + transformations->Add(new FuseBroadcastIntoFollowingBinary); transformations->Add(new MergeReshapeIntoPrecedingTranspose); transformations->Add(new ReorderElementwiseUnary); transformations->Add(new ReorderReshapeTranspose); @@ -94,7 +96,6 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowMerge); transformations->Add(new ResolveSqueezeAttributes); transformations->Add(new ResolveTensorFlowSwitch); - transformations->Add(new ResolveTensorFlowTile); transformations->Add(new ResolveTensorFlowConcat); transformations->Add(new ResolveMultiplyByZero); transformations->Add(new IdentifyDilatedConv); diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index fe7bed885d8003ddef015af1bd846eef43fa7f47..a52c812ef45a3f82c6ca7812067e5a4b9bda3a67 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" namespace toco { @@ -338,23 +338,23 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Div) HANDLE_OPERATORTYPENAME_CASE(Tanh) HANDLE_OPERATORTYPENAME_CASE(Sin) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert) + HANDLE_OPERATORTYPENAME_CASE(All) + HANDLE_OPERATORTYPENAME_CASE(Assert) HANDLE_OPERATORTYPENAME_CASE(ExpandDims) HANDLE_OPERATORTYPENAME_CASE(Fill) HANDLE_OPERATORTYPENAME_CASE(FloorMod) HANDLE_OPERATORTYPENAME_CASE(FloorDiv) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum) + HANDLE_OPERATORTYPENAME_CASE(Greater) + HANDLE_OPERATORTYPENAME_CASE(GreaterEqual) + HANDLE_OPERATORTYPENAME_CASE(Identity) + HANDLE_OPERATORTYPENAME_CASE(Less) + HANDLE_OPERATORTYPENAME_CASE(LessEqual) + HANDLE_OPERATORTYPENAME_CASE(MatMul) + HANDLE_OPERATORTYPENAME_CASE(Max) // Reduction Max + HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum + HANDLE_OPERATORTYPENAME_CASE(Merge) + HANDLE_OPERATORTYPENAME_CASE(Min) // Reduction Min + HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum HANDLE_OPERATORTYPENAME_CASE(Neg) HANDLE_OPERATORTYPENAME_CASE(Pad) HANDLE_OPERATORTYPENAME_CASE(PadV2) @@ -362,22 +362,22 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Stack) HANDLE_OPERATORTYPENAME_CASE(Range) HANDLE_OPERATORTYPENAME_CASE(Rank) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape) + HANDLE_OPERATORTYPENAME_CASE(Reshape) HANDLE_OPERATORTYPENAME_CASE(Squeeze) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape) + HANDLE_OPERATORTYPENAME_CASE(Rsqrt) + HANDLE_OPERATORTYPENAME_CASE(Shape) HANDLE_OPERATORTYPENAME_CASE(Slice) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch) + HANDLE_OPERATORTYPENAME_CASE(Split) + HANDLE_OPERATORTYPENAME_CASE(Sqrt) + HANDLE_OPERATORTYPENAME_CASE(Square) + HANDLE_OPERATORTYPENAME_CASE(Switch) HANDLE_OPERATORTYPENAME_CASE(Sub) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile) + HANDLE_OPERATORTYPENAME_CASE(Sum) + HANDLE_OPERATORTYPENAME_CASE(Tile) HANDLE_OPERATORTYPENAME_CASE(Transpose) HANDLE_OPERATORTYPENAME_CASE(TransposeConv) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2) + HANDLE_OPERATORTYPENAME_CASE(Concat) + HANDLE_OPERATORTYPENAME_CASE(ConcatV2) HANDLE_OPERATORTYPENAME_CASE(Cast) HANDLE_OPERATORTYPENAME_CASE(Floor) HANDLE_OPERATORTYPENAME_CASE(Gather) @@ -388,12 +388,14 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Svdf) HANDLE_OPERATORTYPENAME_CASE(ArgMax) HANDLE_OPERATORTYPENAME_CASE(TopK_V2) - HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported) + HANDLE_OPERATORTYPENAME_CASE(Unsupported) HANDLE_OPERATORTYPENAME_CASE(Exp) HANDLE_OPERATORTYPENAME_CASE(DynamicPartition) HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) HANDLE_OPERATORTYPENAME_CASE(Select) HANDLE_OPERATORTYPENAME_CASE(SparseToDense) + HANDLE_OPERATORTYPENAME_CASE(Equal) + HANDLE_OPERATORTYPENAME_CASE(NotEqual) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -401,7 +403,7 @@ const char* OperatorTypeName(OperatorType type) { } string HelpfulOperatorTypeName(const Operator& op) { - if (op.type == OperatorType::kTensorFlowUnsupported) { + if (op.type == OperatorType::kUnsupported) { return toco::port::StringF( "(Unsupported TensorFlow op: %s)", static_cast(op).tensorflow_op); @@ -411,16 +413,20 @@ string HelpfulOperatorTypeName(const Operator& op) { bool OperatorSupportsFusedActivation(OperatorType type) { switch (type) { - case OperatorType::kConcatenation: - case OperatorType::kFakeQuant: - case OperatorType::kGather: - case OperatorType::kSlice: - case OperatorType::kSqueeze: - case OperatorType::kTensorFlowReshape: - case OperatorType::kTensorFlowSplit: - return false; - default: + case OperatorType::kAdd: + case OperatorType::kAveragePool: + case OperatorType::kBatchNormalization: + case OperatorType::kConv: + case OperatorType::kDepthwiseConv: + case OperatorType::kDiv: + case OperatorType::kFullyConnected: + case OperatorType::kL2Pool: + case OperatorType::kMaxPool: + case OperatorType::kMul: + case OperatorType::kSub: return true; + default: + return false; } } @@ -583,6 +589,13 @@ void UnextendShape(Shape* shape, int new_shape_size) { shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); } +bool IsValid(const Shape& shape) { + for (int i = 0; i < shape.dimensions_count(); ++i) { + if (shape.dims(i) < 1) return false; + } + return true; +} + void CheckShapeDimensions(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i @@ -1863,18 +1876,15 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, output_axes_order == AxesOrder::kHWIO) { // 3210 <- 3210 // HWIO <- OHWI - (*shuffle)[0] = 1; - (*shuffle)[1] = 2; - (*shuffle)[2] = 3; - (*shuffle)[3] = 0; + *shuffle = {1, 2, 3, 0}; } else if (input_axes_order == AxesOrder::kHWIO && output_axes_order == AxesOrder::kOHWI) { // 3210 <- 3210 // OHWI <- HWIO - (*shuffle)[0] = 3; - (*shuffle)[1] = 0; - (*shuffle)[2] = 1; - (*shuffle)[3] = 2; + *shuffle = {3, 0, 1, 2}; + } else if (input_axes_order == AxesOrder::kOHWI && + output_axes_order == AxesOrder::kHWOI) { + *shuffle = {1, 2, 0, 3}; } else { LOG(FATAL) << "Bad shuffle"; } @@ -2020,6 +2030,8 @@ int AxesCount(AxesOrder axes_order) { return 4; case AxesOrder::kNHWC: return 4; + case AxesOrder::kHWOI: + return 4; default: LOG(FATAL) << "Bad AxesOrder"; return 0; diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 3b320e801349595396e573e225ffacf4c7607e52..791ced8d012209867f0ce7ce417d4d11b59b2ead 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -32,8 +32,9 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" // TODO(aselle): Replace with using a container specific hash override instead. namespace std { @@ -100,6 +101,8 @@ std::vector>::iterator FindOp(Model& model, const char* OperatorTypeName(OperatorType type); string HelpfulOperatorTypeName(const Operator& op); +// Whether the operator can be fused with an activation function. Note that this +// will return false by default for new operators; fusing support is opt-in. bool OperatorSupportsFusedActivation(OperatorType type); void DumpGraphvizVideoFrame(const Model& model); @@ -112,7 +115,9 @@ void ExtendShape(Shape* shape, int new_shape_size); // TODO(b/36075966): Clean up when dims superseded by array shape. void UnextendShape(Shape* shape, int new_shape_size); -// Checks (using CHECK) that all dimensions of 'shape' are at least 1. +// Checks that all dimensions of 'shape' are at least 1. +bool IsValid(const Shape& shape); +// Same as above, but reports error using CHECK. void CheckShapeDimensions(const Shape& shape); // Given two shapes with potentially different dimensionality and dimension @@ -315,7 +320,7 @@ void UseArraysExtraInfo(Model* model, bool quantize_output); // doesn't have enough range to represent the sum of elements, an error is // returned. template -port::Status NumElements(const std::vector& shape, U* num_elements) { +tensorflow::Status NumElements(const std::vector& shape, U* num_elements) { static_assert( std::numeric_limits::max() <= std::numeric_limits::max(), "vector type exceed capabilities of NumElements"); @@ -326,17 +331,17 @@ port::Status NumElements(const std::vector& shape, U* num_elements) { // TensorFlow's shapes sometimes include -1 to represent an "unknown" // size but TOCO isn't able to create arrays of unknown sizes and will // crash in RequiredBufferSizeForShape(). - return port::Status(false, - "Tensor shape should not include negative values"); + return tensorflow::errors::InvalidArgument( + "Tensor shape should not include negative values"); } if (static_cast(dim) > std::numeric_limits::max() / *num_elements) { *num_elements = 0; - return port::Status(false, "Tensor shape is too large"); + return tensorflow::errors::InvalidArgument("Tensor shape is too large"); } *num_elements *= dim; } - return port::Status::OK(); + return tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc index 87fd30db2cf54824a3c34ed875291d898f1a9e38..8609e5beddd200be4e5ebfe1fb2a79048e0e60ab 100644 --- a/tensorflow/contrib/lite/toco/tooling_util_test.cc +++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/lib/core/status.h" namespace toco { @@ -99,7 +100,7 @@ static const char kLargeTensorMessage[] = "Tensor shape is too large"; TEST(NumElementsTest, Int) { int count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 1024, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -114,7 +115,7 @@ TEST(NumElementsTest, Int) { TEST(NumElementsTest, Int32) { int32_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 1024, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -129,7 +130,7 @@ TEST(NumElementsTest, Int32) { TEST(NumElementsTest, Int64) { int64_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{16777216, 16777216, 32767}, &count); EXPECT_TRUE(status.ok()); @@ -144,7 +145,7 @@ TEST(NumElementsTest, Int64) { TEST(NumElementsTest, UnsignedInt32) { uint32_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 2048, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -159,7 +160,7 @@ TEST(NumElementsTest, UnsignedInt32) { TEST(NumElementsTest, UnsignedInt64) { uint64_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{16777216, 16777216, 65535}, &count); @@ -174,4 +175,10 @@ TEST(NumElementsTest, UnsignedInt64) { EXPECT_EQ(status.error_message(), kLargeTensorMessage); } +TEST(FusedActivationTest, DefaultsToUnfused) { + EXPECT_TRUE(OperatorSupportsFusedActivation(OperatorType::kAdd)); + EXPECT_FALSE(OperatorSupportsFusedActivation(OperatorType::kNone)); + EXPECT_FALSE(OperatorSupportsFusedActivation(static_cast(255))); +} + } // namespace toco diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD index 4824a4dbdef6f52a63e712a7a79d2e8f3cec616a..183a545295f690decec47f1c31aa473667408a3d 100644 --- a/tensorflow/contrib/lite/tools/benchmark/BUILD +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -5,8 +5,10 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") -common_copts = ["-Wall"] +common_copts = ["-Wall"] + tflite_copts() cc_binary( name = "benchmark_model", @@ -15,13 +17,10 @@ cc_binary( "logging.h", ], copts = common_copts, - linkopts = select({ + linkopts = tflite_linkopts() + select({ "//tensorflow:android": [ - "-pie", - "-landroid", - "-lm", - "-z defs", - "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm ], "//conditions:default": [], }), @@ -35,7 +34,6 @@ cc_library( srcs = ["command_line_flags.cc"], hdrs = ["command_line_flags.h"], copts = common_copts, - visibility = ["//visibility:private"], ) cc_test( @@ -68,6 +66,16 @@ cc_library( ], ) +cc_library( + name = "benchmark_params", + srcs = [ + "benchmark_params.cc", + "logging.h", + ], + hdrs = ["benchmark_params.h"], + copts = common_copts, +) + cc_library( name = "benchmark_model_lib", srcs = [ @@ -77,6 +85,7 @@ cc_library( hdrs = ["benchmark_model.h"], copts = common_copts, deps = [ + ":benchmark_params", ":command_line_flags", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md index e6f333aa5bb11449d5bf5d6c60cf77088649df8c..c10826afff6d5569545d4b7df73c88d24d9dcd1a 100644 --- a/tensorflow/contrib/lite/tools/benchmark/README.md +++ b/tensorflow/contrib/lite/tools/benchmark/README.md @@ -46,8 +46,6 @@ adb shell /data/local/tmp/benchmark_model \ --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ --input_layer="Placeholder" \ --input_layer_shape="1,224,224,3" \ - --input_layer_type="uint8" \ - --output_layer="MobilenetV1/Predictions/Reshape_1" \ --num_threads=4 ``` @@ -66,8 +64,6 @@ bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \ --graph=mobilenet_quant_v1_224.tflite \ --input_layer="Placeholder" \ --input_layer_shape="1,224,224,3" \ - --input_layer_type="uint8" \ - --output_layer="MobilenetV1/Predictions/Reshape_1" \ --num_threads=4 ``` @@ -93,80 +89,66 @@ This compiles TFLite with profiling enabled, now you can run the benchmark binar ============================== Run Order ============================== [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] - CONV_2D 0.000 9.132 9.132 0.121% 0.121% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] - DEPTHWISE_CONV_2D 9.135 3.280 3.280 0.043% 0.165% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6] - CONV_2D 12.419 6.877 6.877 0.091% 0.256% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] - DEPTHWISE_CONV_2D 19.299 1.708 1.708 0.023% 0.278% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6] - CONV_2D 21.012 4.162 4.162 0.055% 0.334% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6] - DEPTHWISE_CONV_2D 25.177 3.520 3.520 0.047% 0.380% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6] - CONV_2D 28.701 10.218 10.218 0.136% 0.516% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] - DEPTHWISE_CONV_2D 38.922 0.827 0.827 0.011% 0.527% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6] - CONV_2D 39.752 1.401 1.401 0.019% 0.545% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6] - DEPTHWISE_CONV_2D 41.156 1.290 1.290 0.017% 0.563% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6] - CONV_2D 42.448 5.995 5.995 0.080% 0.642% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] - DEPTHWISE_CONV_2D 48.445 0.409 0.409 0.005% 0.647% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] - CONV_2D 48.856 6.167 6.167 0.082% 0.729% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] - DEPTHWISE_CONV_2D 55.026 0.629 0.629 0.008% 0.738% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] - CONV_2D 55.656 6.464 6.464 0.086% 0.823% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] - DEPTHWISE_CONV_2D 62.124 0.647 0.647 0.009% 0.832% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] - CONV_2D 62.774 14.666 14.666 0.195% 1.026% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] - DEPTHWISE_CONV_2D 77.444 0.635 0.635 0.008% 1.035% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] - CONV_2D 78.081 7.186 7.186 0.095% 1.130% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] - DEPTHWISE_CONV_2D 85.270 0.646 0.646 0.009% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6] - CONV_2D 85.918 9.529 9.529 0.126% 1.265% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] - DEPTHWISE_CONV_2D 95.451 0.628 0.628 0.008% 1.273% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6] - CONV_2D 96.081 2.077 2.077 0.028% 1.301% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] - DEPTHWISE_CONV_2D 98.162 0.168 0.168 0.002% 1.303% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6] - CONV_2D 98.332 1.007 1.007 0.013% 1.317% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6] - DEPTHWISE_CONV_2D 99.342 0.288 0.288 0.004% 1.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6] - CONV_2D 99.632 8.197 8.197 0.109% 1.429% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] - AVERAGE_POOL_2D 107.832 0.045 0.045 0.001% 1.430% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool] - CONV_2D 107.878 0.325 0.325 0.004% 1.434% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd] - RESHAPE 108.206 0.003 0.003 0.000% 1.434% 0.000 0 [MobilenetV1/Predictions/Reshape] - SOFTMAX 108.211 0.038 0.038 0.001% 1.434% 0.000 0 [MobilenetV1/Predictions/Softmax] + CONV_2D 0.000 4.269 4.269 0.107% 0.107% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] + DEPTHWISE_CONV_2D 4.270 2.150 2.150 0.054% 0.161% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6] + CONV_2D 6.421 6.107 6.107 0.153% 0.314% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + DEPTHWISE_CONV_2D 12.528 1.366 1.366 0.034% 0.348% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6] + CONV_2D 13.895 4.195 4.195 0.105% 0.454% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6] + DEPTHWISE_CONV_2D 18.091 1.260 1.260 0.032% 0.485% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6] + CONV_2D 19.352 6.652 6.652 0.167% 0.652% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + DEPTHWISE_CONV_2D 26.005 0.698 0.698 0.018% 0.670% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6] + CONV_2D 26.703 3.344 3.344 0.084% 0.754% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6] + DEPTHWISE_CONV_2D 30.047 0.646 0.646 0.016% 0.770% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6] + CONV_2D 30.694 5.800 5.800 0.145% 0.915% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + DEPTHWISE_CONV_2D 36.495 0.331 0.331 0.008% 0.924% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] + CONV_2D 36.826 2.838 2.838 0.071% 0.995% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] + DEPTHWISE_CONV_2D 39.665 0.439 0.439 0.011% 1.006% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] + CONV_2D 40.105 5.293 5.293 0.133% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + DEPTHWISE_CONV_2D 45.399 0.352 0.352 0.009% 1.147% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] + CONV_2D 45.752 5.322 5.322 0.133% 1.281% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + DEPTHWISE_CONV_2D 51.075 0.357 0.357 0.009% 1.290% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] + CONV_2D 51.432 5.693 5.693 0.143% 1.433% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + DEPTHWISE_CONV_2D 57.126 0.366 0.366 0.009% 1.442% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6] + CONV_2D 57.493 5.472 5.472 0.137% 1.579% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + DEPTHWISE_CONV_2D 62.966 0.364 0.364 0.009% 1.588% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6] + CONV_2D 63.330 5.404 5.404 0.136% 1.724% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + DEPTHWISE_CONV_2D 68.735 0.155 0.155 0.004% 1.728% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6] + CONV_2D 68.891 2.970 2.970 0.074% 1.802% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6] + DEPTHWISE_CONV_2D 71.862 0.206 0.206 0.005% 1.807% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6] + CONV_2D 72.069 5.888 5.888 0.148% 1.955% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + AVERAGE_POOL_2D 77.958 0.036 0.036 0.001% 1.956% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool] + CONV_2D 77.994 1.445 1.445 0.036% 1.992% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd] + RESHAPE 79.440 0.002 0.002 0.000% 1.992% 0.000 0 [MobilenetV1/Predictions/Reshape] + SOFTMAX 79.443 0.029 0.029 0.001% 1.993% 0.000 0 [MobilenetV1/Predictions/Softmax] ============================== Top by Computation Time ============================== [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] - CONV_2D 62.774 14.666 14.666 0.195% 0.195% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] - CONV_2D 28.701 10.218 10.218 0.136% 0.330% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] - CONV_2D 85.918 9.529 9.529 0.126% 0.456% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] - CONV_2D 0.000 9.132 9.132 0.121% 0.578% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] - CONV_2D 99.632 8.197 8.197 0.109% 0.686% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] - CONV_2D 78.081 7.186 7.186 0.095% 0.782% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] - CONV_2D 12.419 6.877 6.877 0.091% 0.873% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] - CONV_2D 55.656 6.464 6.464 0.086% 0.958% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] - CONV_2D 48.856 6.167 6.167 0.082% 1.040% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] - CONV_2D 42.448 5.995 5.995 0.080% 1.120% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] - -============================== Top by Memory Use ============================== - [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] - SOFTMAX 108.211 0.038 0.038 0.001% 0.001% 0.000 0 [MobilenetV1/Predictions/Softmax] - RESHAPE 108.206 0.003 0.003 0.000% 0.001% 0.000 0 [MobilenetV1/Predictions/Reshape] - CONV_2D 78.081 7.186 7.186 0.095% 0.096% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] - DEPTHWISE_CONV_2D 77.444 0.635 0.635 0.008% 0.104% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] - CONV_2D 62.774 14.666 14.666 0.195% 0.299% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] - DEPTHWISE_CONV_2D 62.124 0.647 0.647 0.009% 0.307% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] - CONV_2D 55.656 6.464 6.464 0.086% 0.393% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] - DEPTHWISE_CONV_2D 55.026 0.629 0.629 0.008% 0.401% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] - CONV_2D 48.856 6.167 6.167 0.082% 0.483% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] - DEPTHWISE_CONV_2D 48.445 0.409 0.409 0.005% 0.489% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] + CONV_2D 19.352 6.652 6.652 0.167% 0.167% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + CONV_2D 6.421 6.107 6.107 0.153% 0.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + CONV_2D 72.069 5.888 5.888 0.148% 0.468% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + CONV_2D 30.694 5.800 5.800 0.145% 0.613% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + CONV_2D 51.432 5.693 5.693 0.143% 0.756% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + CONV_2D 57.493 5.472 5.472 0.137% 0.893% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + CONV_2D 63.330 5.404 5.404 0.136% 1.029% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + CONV_2D 45.752 5.322 5.322 0.133% 1.162% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + CONV_2D 40.105 5.293 5.293 0.133% 1.295% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + CONV_2D 0.000 4.269 4.269 0.107% 1.402% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] Number of nodes executed: 31 ============================== Summary by node type ============================== [Node type] [count] [avg ms] [avg %] [cdf %] [mem KB] [times called] - CONV_2D 15 1.861 86.679% 86.679% 0.000 0 - DEPTHWISE_CONV_2D 13 0.286 13.321% 100.000% 0.000 0 + CONV_2D 15 1.406 89.270% 89.270% 0.000 0 + DEPTHWISE_CONV_2D 13 0.169 10.730% 100.000% 0.000 0 SOFTMAX 1 0.000 0.000% 100.000% 0.000 0 RESHAPE 1 0.000 0.000% 100.000% 0.000 0 AVERAGE_POOL_2D 1 0.000 0.000% 100.000% 0.000 0 -Timings (microseconds): count=50 first=108164 curr=128308 min=102850 max=197072 avg=150805 std=24368 +Timings (microseconds): count=50 first=79449 curr=81350 min=77385 max=88213 avg=79732 std=1929 Memory (bytes): count=0 31 nodes observed -Average inference timings in us: Warmup: 135310, Init: 12123, no stats: 150988 - +Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9 ``` diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc index a8a9a6112c1ec050be8d0bcfe9dc5f00df40d3ff..08648bcfe26365d180d984fde8f8e04b22eb45dd 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc @@ -48,6 +48,19 @@ namespace tflite { namespace benchmark { using tensorflow::Stat; +BenchmarkParams BenchmarkModel::DefaultParams() { + BenchmarkParams params; + params.AddParam("num_runs", BenchmarkParam::Create(50)); + params.AddParam("run_delay", BenchmarkParam::Create(-1.0f)); + params.AddParam("num_threads", BenchmarkParam::Create(1)); + params.AddParam("benchmark_name", BenchmarkParam::Create("")); + params.AddParam("output_prefix", BenchmarkParam::Create("")); + params.AddParam("warmup_runs", BenchmarkParam::Create(1)); + return params; +} + +BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {} + void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { auto inference_us = results.inference_time_us(); auto init_us = results.startup_latency_us(); @@ -60,24 +73,29 @@ void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { std::vector BenchmarkModel::GetFlags() { return { - Flag("num_runs", ¶ms_.num_runs, "number of runs"), - Flag("run_delay", ¶ms_.run_delay, "delay between runs in seconds"), - Flag("num_threads", ¶ms_.num_threads, "number of threads"), - Flag("benchmark_name", ¶ms_.benchmark_name, "benchmark name"), - Flag("output_prefix", ¶ms_.output_prefix, "benchmark output prefix"), - Flag("warmup_runs", ¶ms_.warmup_runs, - "how many runs to initialize model"), + CreateFlag("num_runs", ¶ms_, "number of runs"), + CreateFlag("run_delay", ¶ms_, "delay between runs in seconds"), + CreateFlag("num_threads", ¶ms_, "number of threads"), + CreateFlag("benchmark_name", ¶ms_, "benchmark name"), + CreateFlag("output_prefix", ¶ms_, + "benchmark output prefix"), + CreateFlag("warmup_runs", ¶ms_, + "how many runs to initialize model"), }; } void BenchmarkModel::LogFlags() { - TFLITE_LOG(INFO) << "Num runs: [" << params_.num_runs << "]"; - TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.run_delay + TFLITE_LOG(INFO) << "Num runs: [" << params_.Get("num_runs") << "]"; + TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" + << params_.Get("run_delay") << "]"; + TFLITE_LOG(INFO) << "Num threads: [" << params_.Get("num_threads") + << "]"; + TFLITE_LOG(INFO) << "Benchmark name: [" + << params_.Get("benchmark_name") << "]"; + TFLITE_LOG(INFO) << "Output prefix: [" + << params_.Get("output_prefix") << "]"; + TFLITE_LOG(INFO) << "Warmup runs: [" << params_.Get("warmup_runs") << "]"; - TFLITE_LOG(INFO) << "Num threads: [" << params_.num_threads << "]"; - TFLITE_LOG(INFO) << "Benchmark name: [" << params_.benchmark_name << "]"; - TFLITE_LOG(INFO) << "Output prefix: [" << params_.output_prefix << "]"; - TFLITE_LOG(INFO) << "Warmup runs: [" << params_.warmup_runs << "]"; } Stat BenchmarkModel::Run(int num_times, RunType run_type) { @@ -91,7 +109,7 @@ Stat BenchmarkModel::Run(int num_times, RunType run_type) { listeners_.OnSingleRunEnd(); run_stats.UpdateStat(end_us - start_us); - SleepForSeconds(params_.run_delay); + SleepForSeconds(params_.Get("run_delay")); } std::stringstream stream; @@ -117,8 +135,10 @@ void BenchmarkModel::Run(int argc, char **argv) { << "ms"; uint64_t input_bytes = ComputeInputBytes(); - Stat warmup_time_us = Run(params_.warmup_runs, WARMUP); - Stat inference_time_us = Run(params_.num_runs, REGULAR); + Stat warmup_time_us = + Run(params_.Get("warmup_runs"), WARMUP); + Stat inference_time_us = + Run(params_.Get("num_runs"), REGULAR); listeners_.OnBenchmarkEnd( {startup_latency_us, input_bytes, warmup_time_us, inference_time_us}); } diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h index d48f693693c2cee0cd2e2a6f2b4c590998feffb3..942e21f67a7f864f16b7b1b85b2599d5c872b5c7 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h" #include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" #include "tensorflow/core/util/stats_calculator.h" @@ -63,17 +64,6 @@ class BenchmarkResults { tensorflow::Stat inference_time_us_; }; -struct BenchmarkParams { - BenchmarkParams() - : num_runs(50), warmup_runs(1), run_delay(-1.0), num_threads(1) {} - int num_runs; - int warmup_runs; - float run_delay; - int num_threads; - std::string benchmark_name; - std::string output_prefix; -}; - class BenchmarkListener { public: virtual void OnBenchmarkStart(const BenchmarkParams& params) {} @@ -130,12 +120,22 @@ class BenchmarkLoggingListener : public BenchmarkListener { void OnBenchmarkEnd(const BenchmarkResults& results) override; }; +template +Flag CreateFlag(const char* name, BenchmarkParams* params, + const std::string& usage) { + return Flag(name, [params, name](const T& val) { params->Set(name, val); }, + params->Get(name), usage); +} + // Benchmarks a model. // // Subclasses need to implement initialization and running of the model. // The results can be collected by adding BenchmarkListener(s). class BenchmarkModel { public: + static BenchmarkParams DefaultParams(); + BenchmarkModel(); + BenchmarkModel(BenchmarkParams params) : params_(std::move(params)) {} virtual ~BenchmarkModel() {} bool ParseFlags(int argc, char** argv); virtual void Init() = 0; diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc new file mode 100644 index 0000000000000000000000000000000000000000..1dcf580a9d4995e6cb3706d3562bc8a2f4670082 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h" + +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +void BenchmarkParam::AssertHasSameType(BenchmarkParam::ParamType a, + BenchmarkParam::ParamType b) { + TFLITE_BENCHMARK_CHECK(a == b) << "Type mismatch while accessing parameter."; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_INT32; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_BOOL; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_FLOAT; +} + +template <> +BenchmarkParam::ParamType BenchmarkParam::GetValueType() { + return BenchmarkParam::ParamType::TYPE_STRING; +} + +void BenchmarkParams::AssertParamExists(const std::string& name) const { + TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found."; +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h new file mode 100644 index 0000000000000000000000000000000000000000..33448dd1623577fdfda6316c588cc60ccbaa1994 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +template +class TypedBenchmarkParam; + +class BenchmarkParam { + protected: + enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING }; + + public: + template + static std::unique_ptr Create(const T& default_value) { + return std::unique_ptr( + new TypedBenchmarkParam(default_value)); + } + + template + TypedBenchmarkParam* AsTyped() { + AssertHasSameType(GetValueType(), type_); + return static_cast*>(this); + } + virtual ~BenchmarkParam() {} + BenchmarkParam(ParamType type) : type_(type) {} + + private: + static void AssertHasSameType(ParamType a, ParamType b); + template + static ParamType GetValueType(); + + const ParamType type_; +}; + +template +class TypedBenchmarkParam : public BenchmarkParam { + public: + TypedBenchmarkParam(const T& value) + : BenchmarkParam(GetValueType()), value_(value) {} + void Set(const T& value) { value_ = value; } + + T Get() { return value_; } + + private: + T value_; +}; + +class BenchmarkParams { + public: + void AddParam(const std::string& name, + std::unique_ptr value) { + params_[name] = std::move(value); + } + + bool HasParam(const std::string& name) const { + return params_.find(name) != params_.end(); + } + + template + void Set(const std::string& name, const T& value) { + AssertParamExists(name); + params_.at(name)->AsTyped()->Set(value); + } + + template + T Get(const std::string& name) const { + AssertParamExists(name); + return params_.at(name)->AsTyped()->Get(); + } + + private: + void AssertParamExists(const std::string& name) const; + std::unordered_map> params_; +}; + +} // namespace benchmark +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc index 2e5b86627322c2c64b8ef665a91595174a5dd8dd..73affc26b034f415ae2a2101e0b558cdb94d8d5b 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -123,29 +123,11 @@ void FillRandomString(tflite::DynamicBuffer* buffer, } } -TfLiteType TfLiteTypeFromString(const string& input_layer_type) { - if (input_layer_type == "string") - return kTfLiteString; - else if (input_layer_type == "float") - return kTfLiteFloat32; - else if (input_layer_type == "uint8") - return kTfLiteUInt8; - else if (input_layer_type == "int32") - return kTfLiteInt32; - else if (input_layer_type == "int64") - return kTfLiteInt64; - else - return kTfLiteNoType; -} - bool PopulateInputLayerInfo( const string& names_string, const string& shapes_string, - const string& types_string, const string& values_string, std::vector* info) { std::vector names = Split(names_string, ','); std::vector shapes = Split(shapes_string, ':'); - std::vector types = Split(types_string, ','); - std::vector values = Split(values_string, ':'); if (names.size() != shapes.size()) { TFLITE_LOG(ERROR) << "The number of items in" @@ -158,17 +140,6 @@ bool PopulateInputLayerInfo( << " --input_layer_shape=1,224,224,4:1,20"; return false; } - if (names.size() != types.size()) { - TFLITE_LOG(ERROR) << "The number of items in" - << " --input_layer_type (" << types_string << ", with " - << types.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." - << " For example --input_layer=input1,input2" - << " --input_layer_type=float,int"; - return false; - } for (int i = 0; i < names.size(); ++i) { info->push_back(BenchmarkTfLiteModel::InputLayerInfo()); @@ -176,10 +147,6 @@ bool PopulateInputLayerInfo( input.name = names[i]; - input.data_type = TfLiteTypeFromString(types[i]); - TFLITE_BENCHMARK_CHECK(input.data_type != kTfLiteNoType) - << types[i] << " was an invalid type"; - TFLITE_BENCHMARK_CHECK(SplitAndParse(shapes[i], ',', &input.shape)) << "Incorrect size string specified: " << shapes[i]; for (int dim : input.shape) { @@ -190,30 +157,42 @@ bool PopulateInputLayerInfo( return false; } } - - if (i < values.size()) { - TFLITE_BENCHMARK_CHECK( - SplitAndParse(values[i], ',', &input.initialization_values)) - << "Incorrect initialization values string specified: " << values[i]; - } } return true; } +BenchmarkParams GetDefaultParams() { + BenchmarkParams default_params = BenchmarkModel::DefaultParams(); + default_params.AddParam("graph", BenchmarkParam::Create("")); + default_params.AddParam("input_layer", + BenchmarkParam::Create("")); + default_params.AddParam("input_layer_shape", + BenchmarkParam::Create("")); + default_params.AddParam("use_nnapi", BenchmarkParam::Create(false)); + return default_params; +} + } // namespace +BenchmarkTfLiteModel::BenchmarkTfLiteModel() + : BenchmarkModel(GetDefaultParams()) { + AddListener(&profiling_listener_); +} + +BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params) + : BenchmarkModel(std::move(params)) { + AddListener(&profiling_listener_); +} + std::vector BenchmarkTfLiteModel::GetFlags() { std::vector flags = BenchmarkTfLiteModel::BenchmarkModel::GetFlags(); std::vector specific_flags = { - Flag("graph", &graph, "graph file name"), - Flag("input_layer", &input_layer_string, "input layer names"), - Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), - Flag("input_layer_type", &input_layer_type_string, "input layer type"), - Flag("input_layer_values", &input_layer_values_string, - "values to initialize the inputs with"), - Flag("output_layer", &output_layer_string, "output layer name"), - Flag("use_nnapi", &use_nnapi, "use nnapi api")}; + CreateFlag("graph", ¶ms_, "graph file name"), + CreateFlag("input_layer", ¶ms_, "input layer names"), + CreateFlag("input_layer_shape", ¶ms_, + "input layer shape"), + CreateFlag("use_nnapi", ¶ms_, "use nnapi api")}; flags.insert(flags.end(), specific_flags.begin(), specific_flags.end()); return flags; @@ -221,23 +200,23 @@ std::vector BenchmarkTfLiteModel::GetFlags() { void BenchmarkTfLiteModel::LogFlags() { BenchmarkModel::LogFlags(); - TFLITE_LOG(INFO) << "Graph: [" << graph << "]"; - TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]"; - TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]"; - TFLITE_LOG(INFO) << "Input types: [" << input_layer_type_string << "]"; - TFLITE_LOG(INFO) << "Output layers: [" << output_layer_string << "]"; - TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]"; + TFLITE_LOG(INFO) << "Graph: [" << params_.Get("graph") << "]"; + TFLITE_LOG(INFO) << "Input layers: [" + << params_.Get("input_layer") << "]"; + TFLITE_LOG(INFO) << "Input shapes: [" + << params_.Get("input_layer_shape") << "]"; + TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get("use_nnapi") << "]"; } bool BenchmarkTfLiteModel::ValidateFlags() { - if (graph.empty()) { + if (params_.Get("graph").empty()) { TFLITE_LOG(ERROR) << "Please specify the name of your TF Lite input file with --graph"; return false; } - return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, - input_layer_type_string, - input_layer_values_string, &inputs); + return PopulateInputLayerInfo(params_.Get("input_layer"), + params_.Get("input_layer_shape"), + &inputs); } uint64_t BenchmarkTfLiteModel::ComputeInputBytes() { @@ -251,6 +230,7 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() { } void BenchmarkTfLiteModel::Init() { + std::string graph = params_.Get("graph"); model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); if (!model) { TFLITE_LOG(FATAL) << "Failed to mmap model " << graph; @@ -272,10 +252,14 @@ void BenchmarkTfLiteModel::Init() { } profiling_listener_.SetInterpreter(interpreter.get()); - if (params_.num_threads != -1) { - interpreter->SetNumThreads(params_.num_threads); + const int32_t num_threads = params_.Get("num_threads"); + + if (num_threads != -1) { + interpreter->SetNumThreads(num_threads); } + bool use_nnapi = params_.Get("use_nnapi"); + interpreter->UseNNAPI(use_nnapi); auto interpreter_inputs = interpreter->inputs(); @@ -293,8 +277,6 @@ void BenchmarkTfLiteModel::Init() { TFLITE_BENCHMARK_CHECK_EQ(t->name, input.name) << "Tensor # " << i << " is named " << t->name << " but flags call it " << input.name; - TFLITE_BENCHMARK_CHECK_EQ(t->type, input.data_type) - << "Could not match the type of input tensor " << t->name; } // Resize all non-string tensors. diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h index e70f6de1bf461f4e946ec83d8eea83ff4a15bfca..50cc3f24b3bd2f31555eac69ff208fa2480449b9 100644 --- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -50,9 +50,8 @@ class ProfilingListener : public BenchmarkListener { // Benchmarks a TFLite model by running tflite interpreter. class BenchmarkTfLiteModel : public BenchmarkModel { public: - BenchmarkTfLiteModel() : use_nnapi(false) { - AddListener(&profiling_listener_); - } + BenchmarkTfLiteModel(); + BenchmarkTfLiteModel(BenchmarkParams params); std::vector GetFlags() override; void LogFlags() override; @@ -64,23 +63,13 @@ class BenchmarkTfLiteModel : public BenchmarkModel { struct InputLayerInfo { std::string name; - TfLiteType data_type; std::vector shape; - // Note that initialization_values is currently unused. - std::vector initialization_values; }; private: std::unique_ptr model; std::unique_ptr interpreter; - std::string graph; - std::string input_layer_string; - std::string input_layer_type_string; - std::string input_layer_shape_string; - std::string input_layer_values_string; - std::string output_layer_string; std::vector inputs; - bool use_nnapi; ProfilingListener profiling_listener_; }; diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc index 723bf67e03d52d9cbca001162682016504a6b39b..ff818b9dcb5ee0b58b95c3dceae74083dbd4f0da 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include #include +#include #include namespace tflite { @@ -35,7 +36,7 @@ bool ParseFlag(const std::string& arg, const std::string& flag, if (arg.find(flag_prefix) != 0) { return false; } - bool has_value = (arg.size() >= flag_prefix.size() + 1); + bool has_value = arg.size() >= flag_prefix.size(); *value_parsing_ok = has_value; if (has_value) { *value_parsing_ok = parse_func(arg.substr(flag_prefix.size())); @@ -44,76 +45,79 @@ bool ParseFlag(const std::string& arg, const std::string& flag, } template -bool ParseFlag(const std::string& flag_value, T* value) { +bool ParseFlag(const std::string& flag_value, + const std::function& hook) { std::istringstream stream(flag_value); T read_value; stream >> read_value; if (!stream.eof() && !stream.good()) { return false; } - *value = read_value; + hook(read_value); return true; } -bool ParseBoolFlag(const std::string& flag_value, bool* value) { +bool ParseBoolFlag(const std::string& flag_value, + const std::function& hook) { if (flag_value != "true" && flag_value != "false") { return false; } - *value = (flag_value == "true"); + hook(flag_value == "true"); return true; } - -bool ParseStringFlag(const std::string& flag_value, std::string* value) { - *value = flag_value; - return true; -} - } // namespace -Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + int32_t default_value, const std::string& usage_text) : name_(name), type_(TYPE_INT32), - value_hook_([dst](const std::string& flag_value) { - return ParseFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); }), - default_for_display_(ToString(*dst)), + default_for_display_(ToString(default_value)), usage_text_(usage_text) {} -Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + int64_t default_value, const std::string& usage_text) : name_(name), type_(TYPE_INT64), - value_hook_([dst](const std::string& flag_value) { - return ParseFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); }), - default_for_display_(ToString(*dst)), + default_for_display_(ToString(default_value)), usage_text_(usage_text) {} -Flag::Flag(const char* name, float* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + float default_value, const std::string& usage_text) : name_(name), type_(TYPE_FLOAT), - value_hook_([dst](const std::string& flag_value) { - return ParseFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseFlag(flag_value, hook); }), - default_for_display_(ToString(*dst)), + default_for_display_(ToString(default_value)), usage_text_(usage_text) {} -Flag::Flag(const char* name, bool* dst, const std::string& usage_text) +Flag::Flag(const char* name, const std::function& hook, + bool default_value, const std::string& usage_text) : name_(name), type_(TYPE_BOOL), - value_hook_([dst](const std::string& flag_value) { - return ParseBoolFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + return ParseBoolFlag(flag_value, hook); }), - default_for_display_((*dst) ? "true" : "false"), + default_for_display_(default_value ? "true" : "false"), usage_text_(usage_text) {} -Flag::Flag(const char* name, std::string* dst, const std::string& usage_text) +Flag::Flag(const char* name, + const std::function& hook, + const std::string& default_value, const std::string& usage_text) : name_(name), type_(TYPE_STRING), - value_hook_([dst](const std::string& flag_value) { - return ParseStringFlag(flag_value, dst); + value_hook_([hook](const std::string& flag_value) { + hook(flag_value); + return true; }), - default_for_display_(*dst), + default_for_display_(default_value), usage_text_(usage_text) {} bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const { diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h index 36f9e64767315a317338bc4d2db2ec2d43bee875..2e514ae3ead3b602b8217998ec09177b1e6a2376 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h @@ -33,10 +33,11 @@ namespace tflite { // int some_int = 10; // bool some_switch = false; // std::string some_name = "something"; +// // std::vector flag_list = { -// Flag("some_int", &some_int, "an integer that affects X"), -// Flag("some_switch", &some_switch, "a bool that affects Y"), -// Flag("some_name", &some_name, "a std::string that affects Z") +// Flag::CreateFlag("some_int", &some_int, "an integer that affects X"), +// Flag::CreateFlag("some_switch", &some_switch, "a bool that affects Y"), +// Flag::CreateFlag("some_name", &some_name, "a string that affects Z") // }; // // Get usage message before ParseFlags() to capture default values. // std::string usage = Flag::Usage(argv[0], flag_list); @@ -63,11 +64,21 @@ namespace tflite { // text, and a pointer to the corresponding variable. class Flag { public: - Flag(const char* name, int32_t* dst, const std::string& usage_text); - Flag(const char* name, int64_t* dst, const std::string& usage_text); - Flag(const char* name, bool* dst, const std::string& usage_text); - Flag(const char* name, std::string* dst, const std::string& usage_text); - Flag(const char* name, float* dst, const std::string& usage_text); + template + static Flag CreateFlag(const char* name, T* val, const char* usage) { + return Flag(name, [val](const T& v) { *val = v; }, *val, usage); + } + + Flag(const char* name, const std::function& hook, + int32_t default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + int64_t default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + float default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + bool default_value, const std::string& usage_text); + Flag(const char* name, const std::function& hook, + const std::string& default_value, const std::string& usage_text); private: friend class Flags; diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc index 74cf59105bd9073def4b51575903b7c91621e0e2..03da8051099899241fa5241374d754adb1aa93c6 100644 --- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc @@ -34,15 +34,15 @@ TEST(CommandLineFlagsTest, BasicUsage) { "--some_name=somethingelse", "--some_float=42.0"}; int argc = 6; - bool parsed_ok = - Flags::Parse(&argc, reinterpret_cast(argv_strings), - { - Flag("some_int32", &some_int32, "some int32"), - Flag("some_int64", &some_int64, "some int64"), - Flag("some_switch", &some_switch, "some switch"), - Flag("some_name", &some_name, "some name"), - Flag("some_float", &some_float, "some float"), - }); + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + { + Flag::CreateFlag("some_int32", &some_int32, "some int32"), + Flag::CreateFlag("some_int64", &some_int64, "some int64"), + Flag::CreateFlag("some_switch", &some_switch, "some switch"), + Flag::CreateFlag("some_name", &some_name, "some name"), + Flag::CreateFlag("some_float", &some_float, "some float"), + }); EXPECT_EQ(true, parsed_ok); EXPECT_EQ(20, some_int32); @@ -53,13 +53,26 @@ TEST(CommandLineFlagsTest, BasicUsage) { EXPECT_EQ(argc, 1); } +TEST(CommandLineFlagsTest, EmptyStringFlag) { + int argc = 2; + std::string some_string = "invalid"; + const char* argv_strings[] = {"program_name", "--some_string="}; + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_string", &some_string, "some string")}); + + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(some_string, ""); + EXPECT_EQ(argc, 1); +} + TEST(CommandLineFlagsTest, BadIntValue) { int some_int = 10; int argc = 2; const char* argv_strings[] = {"program_name", "--some_int=notanumber"}; bool parsed_ok = Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_int", &some_int, "some int")}); + {Flag::CreateFlag("some_int", &some_int, "some int")}); EXPECT_EQ(false, parsed_ok); EXPECT_EQ(10, some_int); @@ -70,9 +83,9 @@ TEST(CommandLineFlagsTest, BadBoolValue) { bool some_switch = false; int argc = 2; const char* argv_strings[] = {"program_name", "--some_switch=notabool"}; - bool parsed_ok = - Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_switch", &some_switch, "some switch")}); + bool parsed_ok = Flags::Parse( + &argc, reinterpret_cast(argv_strings), + {Flag::CreateFlag("some_switch", &some_switch, "some switch")}); EXPECT_EQ(false, parsed_ok); EXPECT_EQ(false, some_switch); @@ -85,7 +98,7 @@ TEST(CommandLineFlagsTest, BadFloatValue) { const char* argv_strings[] = {"program_name", "--some_float=notanumber"}; bool parsed_ok = Flags::Parse(&argc, reinterpret_cast(argv_strings), - {Flag("some_float", &some_float, "some float")}); + {Flag::CreateFlag("some_float", &some_float, "some float")}); EXPECT_EQ(false, parsed_ok); EXPECT_NEAR(-23.23f, some_float, 1e-5f); @@ -121,12 +134,13 @@ TEST(CommandLineFlagsTest, UsageString) { std::string some_name = "something"; // Don't test float in this case, because precision is hard to predict and // match against, and we don't want a flakey test. - const string tool_name = "some_tool_name"; - string usage = Flags::Usage(tool_name + " ", - {Flag("some_int", &some_int, "some int"), - Flag("some_int64", &some_int64, "some int64"), - Flag("some_switch", &some_switch, "some switch"), - Flag("some_name", &some_name, "some name")}); + const std::string tool_name = "some_tool_name"; + std::string usage = Flags::Usage( + tool_name + " ", + {Flag::CreateFlag("some_int", &some_int, "some int"), + Flag::CreateFlag("some_int64", &some_int64, "some int64"), + Flag::CreateFlag("some_switch", &some_switch, "some switch"), + Flag::CreateFlag("some_name", &some_name, "some name")}); // Match the usage message, being sloppy about whitespace. const char* expected_usage = " usage: some_tool_name \n" diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/README.md b/tensorflow/contrib/lite/tools/benchmark/ios/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c8d3307e29efaebdc5c309dc7e4262b54d64943f --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/README.md @@ -0,0 +1,43 @@ +# TFLite iOS benchmark app. + +## Description + +An iOS app to benchmark TFLite models. + +The app reads benchmark parameters from a JSON file named `benchmark_params.json` +in its `benchmark_data` directory. Any downloaded models for benchmarking should +also be placed in `benchmark_data` directory. + +The JSON file specifies the name of the model file and other benchmarking +parameters like inputs to the model, type of inputs, number of iterations, +number of threads. The default values in the JSON file are for the +Mobilenet_1.0_224 model +([paper](https://arxiv.org/pdf/1704.04861.pdf), +[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)) + +## To build/install/run + +- Follow instructions at [iOS build for TFLite] +(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md) +to build TFLite. + +Running + +```bash +tensorflow/contrib/lite/build_ios_universal_lib.sh +``` +will also build `tensorflow/contrib/lite/gen/lib/benchmark-lib.a` . + +- Now copy the downloaded model file to `benchmark_data` directory. + +- Modify `benchmark_params.json` change the `input_layer`, `input_layer_shape` +and other benchmark parameters. + +- Change `Build Phases -> Copy Bundle Resources` and add the model file to the +resources that need to be copied. + +- Ensure that `Build Phases -> Link Binary With Library` contains the +`Accelerate framework` and `tensorflow/contrib/lite/gen/lib/benchmark-lib.a`. + +- Now try running the app. The app has a single button that runs the benchmark + on the model and displays results in a text view below. diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..b908f733d49b56a6b41ebea4185f1fe8c11edc60 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj @@ -0,0 +1,381 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 50; + objects = { + +/* Begin PBXBuildFile section */ + 6FE7579A20D59CE500F01636 /* benchmark_params.json in Resources */ = {isa = PBXBuildFile; fileRef = 6FE7579920D59CE500F01636 /* benchmark_params.json */; }; + 6FE7579D20D5A5E000F01636 /* benchmark-lib.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */; }; + 6FE7579F20D5A6A700F01636 /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 6FE7579E20D5A6A700F01636 /* Accelerate.framework */; }; + 6FE757A120D5AB8100F01636 /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */; }; + 6FE93FFD20D592D8008C9FE4 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */; }; + 6FE9400020D592D8008C9FE4 /* BenchmarkViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */; }; + 6FE9400320D592D8008C9FE4 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 6FE9400120D592D8008C9FE4 /* Main.storyboard */; }; + 6FE9400520D592DA008C9FE4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 6FE9400420D592DA008C9FE4 /* Assets.xcassets */; }; + 6FE9400B20D592DA008C9FE4 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 6FE9400A20D592DA008C9FE4 /* main.m */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 6FE7579920D59CE500F01636 /* benchmark_params.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = benchmark_params.json; sourceTree = ""; }; + 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "benchmark-lib.a"; path = "$SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib/benchmark-lib.a"; sourceTree = ""; }; + 6FE7579E20D5A6A700F01636 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; + 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = ""; }; + 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TFLiteBenchmark.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 6FE93FFB20D592D8008C9FE4 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = ""; }; + 6FE93FFE20D592D8008C9FE4 /* BenchmarkViewController.h */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.h; path = BenchmarkViewController.h; sourceTree = ""; }; + 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = BenchmarkViewController.mm; sourceTree = ""; }; + 6FE9400220D592D8008C9FE4 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; + 6FE9400420D592DA008C9FE4 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + 6FE9400920D592DA008C9FE4 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 6FE9400A20D592DA008C9FE4 /* main.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = main.m; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 6FE93FF520D592D8008C9FE4 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE7579F20D5A6A700F01636 /* Accelerate.framework in Frameworks */, + 6FE7579D20D5A5E000F01636 /* benchmark-lib.a in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 6FE7579820D59C8B00F01636 /* benchmark_data */ = { + isa = PBXGroup; + children = ( + 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */, + 6FE7579920D59CE500F01636 /* benchmark_params.json */, + ); + path = benchmark_data; + sourceTree = ""; + }; + 6FE7579B20D5A5E000F01636 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 6FE7579E20D5A6A700F01636 /* Accelerate.framework */, + 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 6FE93FEF20D592D8008C9FE4 = { + isa = PBXGroup; + children = ( + 6FE93FFA20D592D8008C9FE4 /* TFLiteBenchmark */, + 6FE93FF920D592D8008C9FE4 /* Products */, + 6FE7579B20D5A5E000F01636 /* Frameworks */, + ); + sourceTree = ""; + }; + 6FE93FF920D592D8008C9FE4 /* Products */ = { + isa = PBXGroup; + children = ( + 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */, + ); + name = Products; + sourceTree = ""; + }; + 6FE93FFA20D592D8008C9FE4 /* TFLiteBenchmark */ = { + isa = PBXGroup; + children = ( + 6FE7579820D59C8B00F01636 /* benchmark_data */, + 6FE93FFB20D592D8008C9FE4 /* AppDelegate.h */, + 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */, + 6FE93FFE20D592D8008C9FE4 /* BenchmarkViewController.h */, + 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */, + 6FE9400120D592D8008C9FE4 /* Main.storyboard */, + 6FE9400420D592DA008C9FE4 /* Assets.xcassets */, + 6FE9400920D592DA008C9FE4 /* Info.plist */, + 6FE9400A20D592DA008C9FE4 /* main.m */, + ); + path = TFLiteBenchmark; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 6FE93FF720D592D8008C9FE4 /* TFLiteBenchmark */ = { + isa = PBXNativeTarget; + buildConfigurationList = 6FE9400E20D592DA008C9FE4 /* Build configuration list for PBXNativeTarget "TFLiteBenchmark" */; + buildPhases = ( + 6FE93FF420D592D8008C9FE4 /* Sources */, + 6FE93FF520D592D8008C9FE4 /* Frameworks */, + 6FE93FF620D592D8008C9FE4 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = TFLiteBenchmark; + productName = TFLiteBenchmark; + productReference = 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 6FE93FF020D592D8008C9FE4 /* Project object */ = { + isa = PBXProject; + attributes = { + LastUpgradeCheck = 1000; + ORGANIZATIONNAME = Example; + TargetAttributes = { + 6FE93FF720D592D8008C9FE4 = { + CreatedOnToolsVersion = 10.0; + }; + }; + }; + buildConfigurationList = 6FE93FF320D592D8008C9FE4 /* Build configuration list for PBXProject "TFLiteBenchmark" */; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 6FE93FEF20D592D8008C9FE4; + productRefGroup = 6FE93FF920D592D8008C9FE4 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 6FE93FF720D592D8008C9FE4 /* TFLiteBenchmark */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 6FE93FF620D592D8008C9FE4 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE757A120D5AB8100F01636 /* mobilenet_v1_1.0_224.tflite in Resources */, + 6FE9400520D592DA008C9FE4 /* Assets.xcassets in Resources */, + 6FE9400320D592D8008C9FE4 /* Main.storyboard in Resources */, + 6FE7579A20D59CE500F01636 /* benchmark_params.json in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 6FE93FF420D592D8008C9FE4 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 6FE9400020D592D8008C9FE4 /* BenchmarkViewController.mm in Sources */, + 6FE9400B20D592DA008C9FE4 /* main.m in Sources */, + 6FE93FFD20D592D8008C9FE4 /* AppDelegate.m in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXVariantGroup section */ + 6FE9400120D592D8008C9FE4 /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + 6FE9400220D592D8008C9FE4 /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; +/* End PBXVariantGroup section */ + +/* Begin XCBuildConfiguration section */ + 6FE9400C20D592DA008C9FE4 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 11.0; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + ONLY_ACTIVE_ARCH = YES; + OTHER_CFLAGS = ""; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + SDKROOT = iphoneos; + }; + name = Debug; + }; + 6FE9400D20D592DA008C9FE4 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 11.0; + MTL_ENABLE_DEBUG_INFO = NO; + OTHER_CFLAGS = ""; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + SDKROOT = iphoneos; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 6FE9400F20D592DA008C9FE4 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + "HEADER_SEARCH_PATHS[arch=*]" = ( + $SRCROOT/../../../../../../../, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/eigen, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/gemmlowp, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/neon_2_sse, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/farmhash/src, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/flatbuffers/include, + ); + INFOPLIST_FILE = TFLiteBenchmark/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib; + PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark; + PRODUCT_NAME = "$(TARGET_NAME)"; + TARGETED_DEVICE_FAMILY = "1,2"; + "USER_HEADER_SEARCH_PATHS[arch=*]" = ""; + }; + name = Debug; + }; + 6FE9401020D592DA008C9FE4 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + "HEADER_SEARCH_PATHS[arch=*]" = ( + $SRCROOT/../../../../../../../, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/eigen, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/gemmlowp, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/neon_2_sse, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/farmhash/src, + $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/flatbuffers/include, + ); + INFOPLIST_FILE = TFLiteBenchmark/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib; + PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark; + PRODUCT_NAME = "$(TARGET_NAME)"; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 6FE93FF320D592D8008C9FE4 /* Build configuration list for PBXProject "TFLiteBenchmark" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 6FE9400C20D592DA008C9FE4 /* Debug */, + 6FE9400D20D592DA008C9FE4 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 6FE9400E20D592DA008C9FE4 /* Build configuration list for PBXNativeTarget "TFLiteBenchmark" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 6FE9400F20D592DA008C9FE4 /* Debug */, + 6FE9401020D592DA008C9FE4 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 6FE93FF020D592D8008C9FE4 /* Project object */; +} diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h new file mode 100644 index 0000000000000000000000000000000000000000..a55c03e00b5065e3b149c65f820f11d13c064d87 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h @@ -0,0 +1,22 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m new file mode 100644 index 0000000000000000000000000000000000000000..b1165940e9a29ac693d473a1c852b7b0681392fc --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m @@ -0,0 +1,27 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + return YES; +} +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..d8db8d65fd79fd541b2b7eba75c7378af3448f9c --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,98 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : "ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json new file mode 100644 index 0000000000000000000000000000000000000000..da4a164c918651cdd1e11dca5cc62c333f097601 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..bfa36129419f8bd7ad73581cb9f07b8c6eec3fcf --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..adcfe1ef4e708ea6f87c77f4a740b58e5027d3e5 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h new file mode 100644 index 0000000000000000000000000000000000000000..ec6dea0546060881682c44ad451f4812a2f3d7ea --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h @@ -0,0 +1,21 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface BenchmarkViewController : UIViewController +@property(weak, nonatomic) IBOutlet UITextView *resultsView; + +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm new file mode 100644 index 0000000000000000000000000000000000000000..356d5b0e17abc715de9b8f7a20ec7459f3468da1 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm @@ -0,0 +1,125 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "BenchmarkViewController.h" +#import +#import +#import +#import +#import "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#import "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace { +NSString* FilePathForResourceName(NSString* filename) { + NSString* name = [filename stringByDeletingPathExtension]; + NSString* extension = [filename pathExtension]; + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + TFLITE_LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; + } + return file_path; +} + +NSDictionary* ParseJson() { + NSString* params_json_path = FilePathForResourceName(@"benchmark_params.json"); + NSData* data = [NSData dataWithContentsOfFile:params_json_path]; + return [NSJSONSerialization JSONObjectWithData:data options:kNilOptions error:nil]; +} + +std::string FormatCommandLineParam(NSString* key, NSString* value) { + std::ostringstream stream; + stream << "--" << [key UTF8String] << "=" << [value UTF8String]; + return stream.str(); +} + +// Reads the |benchmark_params.json| to read command line parameters and returns them as a vector of +// strings. +void ReadCommandLineParameters(std::vector* params) { + NSDictionary* param_dict = ParseJson(); + for (NSString* key in param_dict) { + NSString* value = param_dict[key]; + if ([key isEqualToString:@"graph"]) { + value = FilePathForResourceName(value); + } + params->push_back(FormatCommandLineParam(key, value)); + } +} +std::vector StringVecToCharPtrVec(const std::vector& str_vec) { + std::vector charptr_vec; + std::transform(str_vec.begin(), str_vec.end(), std::back_inserter(charptr_vec), + [](const std::string& s) -> char* { return const_cast(s.c_str()); }); + return charptr_vec; +} + +class ResultsListener : public tflite::benchmark::BenchmarkListener { + public: + void OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults& results) override; + std::string Results() { return results_; } + + private: + std::string results_; +}; + +void OutputMicrosecondsStatToStream(const tensorflow::Stat& time_us, + const std::string& prefix, std::ostringstream* stream) { + *stream << prefix << "Num runs: " << time_us.count() << "\n"; + + *stream << prefix << "Average: " << time_us.avg() / 1e3 << " ms\n"; + *stream << prefix << "Min: " << time_us.min() / 1e3 << " ms \n"; + *stream << prefix << "Max: " << time_us.max() / 1e3 << " ms \n"; + *stream << prefix << "Std deviation: " << time_us.std_deviation() / 1e3 << " ms\n"; +} + +void ResultsListener::OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults& results) { + std::ostringstream stream; + const std::string prefix = " - "; + stream << "Startup latency: "; + stream << results.startup_latency_us() / 1e3 << " ms\n"; + stream << "\nInference:\n"; + OutputMicrosecondsStatToStream(results.inference_time_us(), prefix, &stream); + stream << "\nWarmup:\n"; + OutputMicrosecondsStatToStream(results.warmup_time_us(), prefix, &stream); + + results_ = stream.str(); +} + +std::string RunBenchmark() { + ResultsListener listener; + tflite::benchmark::BenchmarkTfLiteModel benchmark; + benchmark.AddListener(&listener); + // TODO(shashishekhar): Passing arguments like this is brittle, refactor the BenchmarkParams + // so that it contains arguments for BenchmarkTfLiteModel and set parameters using BenchmarkParams + std::vector command_line_params; + // Benchmark model expects first arg to be program name. + // push a string for name of program. + command_line_params.push_back("benchmark_tflite_model"); + ReadCommandLineParameters(&command_line_params); + std::vector argv = StringVecToCharPtrVec(command_line_params); + int argc = static_cast(argv.size()); + benchmark.Run(argc, argv.data()); + return listener.Results(); +} +} // namespace + +@interface BenchmarkViewController () +@end + +@implementation BenchmarkViewController +- (IBAction)onBenchmarkModel:(UIButton*)sender { + std::string results = RunBenchmark(); + [_resultsView setText:[NSString stringWithUTF8String:results.c_str()]]; +} +@end diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..96051cf08ff54b51f458eca6f0126dd99dfc51dc --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist @@ -0,0 +1,43 @@ + + + + + UILaunchStoryboardName + Main + CFBundleDevelopmentRegion + $(DEVELOPMENT_LANGUAGE) + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + + diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json new file mode 100644 index 0000000000000000000000000000000000000000..d344a7a5efaef53500bc0f88d29ca7aecf59290a --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json @@ -0,0 +1,10 @@ +{ + "benchmark_name" : "mobile_net_benchmark", + "num_threads" : "4", + "num_runs" : "20", + "warmup_runs" : "1", + "graph" : "mobilenet_v1_1.0_224.tflite", + "input_layer" : "input", + "input_layer_shape" : "1,224,224,3", + "run_delay" : "-1" +} diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m new file mode 100644 index 0000000000000000000000000000000000000000..1e70b9cd1d82f320ec048642520dbc54dc0f7934 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m @@ -0,0 +1,23 @@ +// Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import "AppDelegate.h" + +int main(int argc, char* argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc index ce8a7857d2dd66b12e9ea970911ef1dd01e4550e..ad7d59ecb41a0c81a6a4d8edae5fa6b4b5a7bede 100644 --- a/tensorflow/contrib/lite/tools/verifier_test.cc +++ b/tensorflow/contrib/lite/tools/verifier_test.cc @@ -41,7 +41,7 @@ class TfLiteFlatbufferModelBuilder { } TfLiteFlatbufferModelBuilder(const std::vector& builtin_ops, - const std::vector& custom_ops) { + const std::vector& custom_ops) { buffers_.push_back( CreateBuffer(builder_, builder_.CreateVector(std::vector{}))); @@ -194,8 +194,8 @@ TEST(VerifyModel, TensorBufferIsNotValid) { /*operators=*/0, builder.CreateString("Main"))}); auto buffers = builder.CreateVector(std::vector>{ - CreateBuffer(builder, - builder.CreateVector(std::vector{1, 2, 3, 4, 5, 6})), + CreateBuffer(builder, builder.CreateVector( + std::vector{1, 2, 3, 4, 5, 6})), }); auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, /*operator_codes=*/0, diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5d4682ec9f4b8c5864383bd1d2f4c0b41a11baad..5a080cceabb55c307dcd1a457a9e30d24e0bd172 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib import lookup from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -1396,15 +1397,22 @@ class KeyValueTensorInitializerTest(test.TestCase): class IndexTableFromTensor(test.TestCase): + @test_util.run_in_graph_and_eager_modes() def test_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + table = lookup.index_table_from_tensor( + mapping=("brain", "salad", "surgery"), num_oov_buckets=1) + + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(table.lookup( + constant_op.constant(("salad", "surgery", "tarkus")))) + else: + # Reinitializing a table in eager should work. table = lookup.index_table_from_tensor( mapping=("brain", "salad", "surgery"), num_oov_buckets=1) - ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) - - self.assertRaises(errors_impl.OpError, ids.eval) - lookup_ops.tables_initializer().run() - self.assertAllEqual((1, 2, 3), ids.eval()) + self.evaluate(lookup_ops.tables_initializer()) + ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_int32_index_table_from_tensor_with_tensor_init(self): with self.test_session(): diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh index fc88f59e0948e1d3ed7cce9b809bf30ba280af12..fb9e77ae1bcfc3404f1fdf90ab2697a4e79a9836 100755 --- a/tensorflow/contrib/makefile/build_all_android.sh +++ b/tensorflow/contrib/makefile/build_all_android.sh @@ -30,6 +30,14 @@ arm64-v8a armeabi armeabi-v7a mips mips64 x86 x86_64 tegra)" exit 1 } +echo "********************************************************************" +echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference." +echo "You are currently using an older version. Please switch over to TensorFlow Lite." +echo "" +echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite" +echo "********************************************************************" +echo "" + if [[ -z "${NDK_ROOT}" ]]; then echo "NDK_ROOT should be set as an environment variable" 1>&2 exit 1 diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh index 0a458a27b3ac9b1a24b0f42de2f0166d515e8cd9..1d4677ef4bd1e8811998d1464e63902544153a49 100755 --- a/tensorflow/contrib/makefile/build_all_ios.sh +++ b/tensorflow/contrib/makefile/build_all_ios.sh @@ -31,6 +31,14 @@ usage() { exit 1 } +echo "********************************************************************" +echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference." +echo "You are currently using an older version. Please switch over to TensorFlow Lite." +echo "" +echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite" +echo "********************************************************************" +echo "" + DEFAULT_ARCH="i386 x86_64 armv7 armv7s arm64" while getopts "a:g:T" opt_name; do case "$opt_name" in diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index 4f2c82ca23011667662c74507fcbd99bcde4c7c0..66cb493e5c5bb9b8645e87dc7f5b274d916f64fc 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -77,7 +77,31 @@ py_test( py_test( name = "metric_ops_test", srcs = ["python/ops/metric_ops_test.py"], - shard_count = 16, + shard_count = 30, + srcs_version = "PY2AND3", + tags = ["noasan"], # times out b/63678675 + deps = [ + ":metrics_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "metric_ops_large_test", + size = "large", + srcs = ["python/ops/metric_ops_large_test.py"], srcs_version = "PY2AND3", tags = ["noasan"], # times out b/63678675 deps = [ diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index a6be2084aae6bb05f958929b45977ed21b570603..b14202ff9ec38016f926ee37c8acbd2bbb4c6ef5 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1064,7 +1064,7 @@ def streaming_auc(predictions, name=name) -def _compute_dynamic_auc(labels, predictions, curve='ROC'): +def _compute_dynamic_auc(labels, predictions, curve='ROC', weights=None): """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. Computes the area under the ROC or PR curve using each prediction as a @@ -1077,13 +1077,22 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): predictions: A 1-D `Tensor` of predictions whose values are `float64`. curve: The name of the curve to be computed, 'ROC' for the Receiving Operating Characteristic or 'PR' for the Precision-Recall curve. + weights: A 1-D `Tensor` of weights whose values are `float64`. Returns: A scalar `Tensor` containing the area-under-curve value for the input. """ - # Count the total number of positive and negative labels in the input. + # Compute the total weight and the total positive weight. size = array_ops.size(predictions) - total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32) + if weights is None: + weights = array_ops.ones_like(labels, dtype=dtypes.float64) + labels, predictions, weights = metrics_impl._remove_squeezable_dimensions( + labels, predictions, weights) + total_weight = math_ops.reduce_sum(weights) + total_positive = math_ops.reduce_sum( + array_ops.where( + math_ops.greater(labels, 0), weights, + array_ops.zeros_like(labels, dtype=dtypes.float64))) def continue_computing_dynamic_auc(): """Continues dynamic auc computation, entered if labels are not all equal. @@ -1091,9 +1100,11 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): Returns: A scalar `Tensor` containing the area-under-curve value. """ - # Sort the predictions descending, and the corresponding labels as well. + # Sort the predictions descending, keeping the same order for the + # corresponding labels and weights. ordered_predictions, indices = nn.top_k(predictions, k=size) ordered_labels = array_ops.gather(labels, indices) + ordered_weights = array_ops.gather(weights, indices) # Get the counts of the unique ordered predictions. _, _, counts = array_ops.unique_with_counts(ordered_predictions) @@ -1103,23 +1114,39 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32) # Count the positives to the left of the split indices. - positives = math_ops.cast( - array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]), - dtypes.int32) - true_positives = array_ops.gather(positives, splits) + true_positives = array_ops.gather( + array_ops.pad( + math_ops.cumsum( + array_ops.where( + math_ops.greater(ordered_labels, 0), ordered_weights, + array_ops.zeros_like(ordered_labels, + dtype=dtypes.float64))), + paddings=[[1, 0]]), splits) if curve == 'ROC': - # Count the negatives to the left of every split point and the total - # number of negatives for computing the FPR. - false_positives = math_ops.subtract(splits, true_positives) - total_negative = size - total_positive + # Compute the weight of the negatives to the left of every split point and + # the total weight of the negatives number of negatives for computing the + # FPR. + false_positives = array_ops.gather( + array_ops.pad( + math_ops.cumsum( + array_ops.where( + math_ops.less(ordered_labels, 1), ordered_weights, + array_ops.zeros_like( + ordered_labels, dtype=dtypes.float64))), + paddings=[[1, 0]]), splits) + total_negative = total_weight - total_positive x_axis_values = math_ops.truediv(false_positives, total_negative) y_axis_values = math_ops.truediv(true_positives, total_positive) elif curve == 'PR': x_axis_values = math_ops.truediv(true_positives, total_positive) # For conformance, set precision to 1 when the number of positive # classifications is 0. + positives = array_ops.gather( + array_ops.pad(math_ops.cumsum(ordered_weights), paddings=[[1, 0]]), + splits) y_axis_values = array_ops.where( - math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits), + math_ops.greater(splits, 0), + math_ops.truediv(true_positives, positives), array_ops.ones_like(true_positives, dtype=dtypes.float64)) # Calculate trapezoid areas. @@ -1133,7 +1160,7 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'): return control_flow_ops.cond( math_ops.logical_or( math_ops.equal(total_positive, 0), math_ops.equal( - total_positive, size)), + total_positive, total_weight)), true_fn=lambda: array_ops.constant(0, dtypes.float64), false_fn=continue_computing_dynamic_auc) @@ -1143,7 +1170,8 @@ def streaming_dynamic_auc(labels, curve='ROC', metrics_collections=(), updates_collections=(), - name=None): + name=None, + weights=None): """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds. USAGE NOTE: this approach requires storing all of the predictions and labels @@ -1168,6 +1196,8 @@ def streaming_dynamic_auc(labels, should be added to. name: An optional name for the variable_scope that contains the metric variables. + weights: A 'Tensor' of non-negative weights whose values are castable to + `float64`. Will be flattened into a 1-D `Tensor`. Returns: auc: A scalar `Tensor` containing the current area-under-curve value. @@ -1195,14 +1225,24 @@ def streaming_dynamic_auc(labels, check_ops.assert_less_equal( labels, array_ops.ones_like(labels, dtypes.int64), - message='labels must be 0 or 1, at least one is >1') + message='labels must be 0 or 1, at least one is >1'), ]): preds_accum, update_preds = streaming_concat( predictions, name='concat_preds') labels_accum, update_labels = streaming_concat( labels, name='concat_labels') - update_op = control_flow_ops.group(update_labels, update_preds) - auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve) + if weights is not None: + weights = array_ops.reshape( + math_ops.cast(weights, dtypes.float64), [-1]) + weights_accum, update_weights = streaming_concat( + weights, name='concat_weights') + update_op = control_flow_ops.group(update_labels, update_preds, + update_weights) + else: + weights_accum = None + update_op = control_flow_ops.group(update_labels, update_preds) + auc = _compute_dynamic_auc( + labels_accum, preds_accum, curve=curve, weights=weights_accum) if updates_collections: ops.add_to_collections(updates_collections, update_op) if metrics_collections: diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7acfc383eb9a659a600752cf57b4978daa8a07bc --- /dev/null +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Large tests for metric_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.metrics.python.ops import metric_ops +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testLargeCase(self): + shape = [32, 512, 256, 1] + predictions = random_ops.random_uniform( + shape, 0.0, 1.0, dtype=dtypes_lib.float32) + labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5) + + result, update_op = metric_ops.precision_recall_at_equal_thresholds( + labels=labels, predictions=predictions, num_thresholds=201) + # Run many updates, enough to cause highly inaccurate values if the + # code used float32 for accumulation. + num_updates = 71 + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_updates): + sess.run(update_op) + + prdata = sess.run(result) + + # Since we use random values, we won't know the tp/fp/tn/fn values, but + # tp and fp at threshold 0 should be the total number of positive and + # negative labels, hence their sum should be total number of pixels. + expected_value = 1.0 * np.product(shape) * num_updates + got_value = prdata.tp[0] + prdata.fp[0] + # They should be at least within 1. + self.assertNear(got_value, expected_value, 1.0) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 4ccba4a253b142dbdec9861fc0d80247c2ded50b..a09fc4abd461323d67e914c70932688816fed764 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2127,6 +2127,44 @@ class StreamingDynamicAUCTest(test.TestCase): sess.run(update_op) self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5) + def testWithWeights(self): + batch_size = 10 + num_batches = 100 + labels = np.array([]) + predictions = np.array([]) + weights = np.array([]) + tf_labels = variables.Variable( + array_ops.ones(batch_size, dtypes_lib.int32), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.int32) + tf_predictions = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + tf_weights = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + auc, update_op = metrics.streaming_dynamic_auc(tf_labels, + tf_predictions, + weights=tf_weights) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_batches): + new_labels = np.random.randint(0, 2, size=batch_size) + noise = np.random.uniform(-0.2, 0.2, size=batch_size) + new_predictions = 0.4 + 0.2 * new_labels + noise + new_weights = np.random.uniform(0.0, 3.0, size=batch_size) + labels = np.concatenate([labels, new_labels]) + predictions = np.concatenate([predictions, new_predictions]) + weights = np.concatenate([weights, new_weights]) + sess.run([tf_labels.assign(new_labels), + tf_predictions.assign(new_predictions), + tf_weights.assign(new_weights)]) + sess.run(update_op) + expected_auc = _np_auc(predictions, labels, weights) + self.assertAlmostEqual(expected_auc, auc.eval()) + class AucWithConfidenceIntervalsTest(test.TestCase): @@ -2391,33 +2429,6 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): for _ in range(3): self._testResultsEqual(initial_result, result) - def testLargeCase(self): - shape = [32, 512, 256, 1] - predictions = random_ops.random_uniform( - shape, 0.0, 1.0, dtype=dtypes_lib.float32) - labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5) - - result, update_op = metric_ops.precision_recall_at_equal_thresholds( - labels=labels, predictions=predictions, num_thresholds=201) - # Run many updates, enough to cause highly inaccurate values if the - # code used float32 for accumulation. - num_updates = 71 - - with self.test_session() as sess: - sess.run(variables.local_variables_initializer()) - for _ in xrange(num_updates): - sess.run(update_op) - - prdata = sess.run(result) - - # Since we use random values, we won't know the tp/fp/tn/fn values, but - # tp and fp at threshold 0 should be the total number of positive and - # negative labels, hence their sum should be total number of pixels. - expected_value = 1.0 * np.product(shape) * num_updates - got_value = prdata.tp[0] + prdata.fp[0] - # They should be at least within 1. - self.assertNear(got_value, expected_value, 1.0) - def _testCase(self, predictions, labels, @@ -4726,199 +4737,204 @@ class StreamingSparseRecallTest(test.TestCase): self._test_sparse_recall_at_top_k( labels, top_k_predictions, expected=1.0 / 2) - def test_one_label_at_k1_weighted(self): + def _test_one_label_at_k1_weighted(self, labels): predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]] top_k_predictions = [[3], [3]] - sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], - [0, 0, 1, 0]]) - dense_labels = np.array([[3], [2]], dtype=np.int64) - for labels in (sparse_labels, dense_labels): - # Class 3: 1 label, 2 predictions, 1 correct. - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0,)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(2.0,)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(2.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=NAN, - class_id=3, - weights=(0.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=NAN, - class_id=3, - weights=(0.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=NAN, - class_id=3, - weights=(0.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=NAN, - class_id=3, - weights=(0.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=1.0 / 1, - class_id=3, - weights=(1.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=2.0 / 2, - class_id=3, - weights=(2.0, 3.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=2.0 / 2, - class_id=3, - weights=(2.0, 3.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=3.0 / 3, - class_id=3, - weights=(3.0, 2.0)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=3.0 / 3, - class_id=3, - weights=(3.0, 2.0)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=0.3 / 0.3, - class_id=3, - weights=(0.3, 0.6)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=0.3 / 0.3, - class_id=3, - weights=(0.3, 0.6)) - self._test_streaming_sparse_recall_at_k( - predictions, - labels, - k=1, - expected=0.6 / 0.6, - class_id=3, - weights=(0.6, 0.3)) - self._test_sparse_recall_at_top_k( - labels, - top_k_predictions, - expected=0.6 / 0.6, - class_id=3, - weights=(0.6, 0.3)) + # Class 3: 1 label, 2 predictions, 1 correct. + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(2.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=NAN, + class_id=3, + weights=(0.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=NAN, + class_id=3, + weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=NAN, + class_id=3, + weights=(0.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=1.0 / 1, + class_id=3, + weights=(1.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=2.0 / 2, + class_id=3, + weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=2.0 / 2, + class_id=3, + weights=(2.0, 3.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=3.0 / 3, + class_id=3, + weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=3.0 / 3, + class_id=3, + weights=(3.0, 2.0)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=0.3 / 0.3, + class_id=3, + weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.3 / 0.3, + class_id=3, + weights=(0.3, 0.6)) + self._test_streaming_sparse_recall_at_k( + predictions, + labels, + k=1, + expected=0.6 / 0.6, + class_id=3, + weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, + top_k_predictions, + expected=0.6 / 0.6, + class_id=3, + weights=(0.6, 0.3)) - # All classes: 2 labels, 2 predictions, 1 correct. - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=NAN, weights=(0.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=NAN, weights=(0.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,)) + # All classes: 2 labels, 2 predictions, 1 correct. + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=NAN, weights=(0.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=NAN, weights=(0.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6)) - self._test_streaming_sparse_recall_at_k( - predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) - self._test_sparse_recall_at_top_k( - labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3)) + self._test_streaming_sparse_recall_at_k( + predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3)) + self._test_sparse_recall_at_top_k( + labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3)) + + def test_one_label_at_k1_weighted_sparse_labels(self): + sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1], + [0, 0, 1, 0]]) + self._test_one_label_at_k1_weighted(sparse_labels) + + def test_one_label_at_k1_weighted_dense_labels(self): + dense_labels = np.array([[3], [2]], dtype=np.int64) + self._test_one_label_at_k1_weighted(dense_labels) def test_three_labels_at_k5_nan(self): predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9], diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py index e4e5ccc33472ad5a12bd8111fb1ff6ebbd6f45f9..ef34f7bf7bf3eba047b50ce8abf883b0ed741a63 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py @@ -26,26 +26,32 @@ from tensorflow.python.training import optimizer class LossScaleOptimizer(optimizer.Optimizer): + # TODO(jamesqin): move mixed precision training explanation to __init__ + # docstring. """An optimizer that applies loss scaling in backprop. - This class is useful for mixed precision training on GPUs (or other potential - accelerators), which is an approach to improve compute throughput without loss - of model quality. - - The commmon configuration of mixed precision models is the following: - * variables are kept in high precision (e.g. float32). - * computations are done in lower precision (e.g. float16). variables are - casted to lower precision before they're used. - * (in training), final gradients are casted back to variable precision and get - applied. - - Because computations happen in lower precision, gradients in the backprop pass - might underflow in the smaller dynamic range, causing a model to converge at a - suboptimal level. This optimizer multiplies the loss by a factor before - backprop starts to prevent underflow. Before gradients are applied, they are - casted to higher precision and down-scaled by the same factor, so - mathematically the variable updates are no different from regular - same-precision training. + This class is useful for "mixed precision training" on GPUs (or other + potential accelerators), an approach to improve compute throughput without + compromising model quality. + + The canonical way to perform mixed precision training is the following: + * Model variables are kept in high precision (e.g. float32). + * Computations are done in lower precision (e.g. float16), which enjoys + performance speedup by virtue of hardware support. Variables are casted to + lower precision before they're used. + * Final gradients are casted back to high precision dtype, then used to update + variables. + + The side-effect of performing computation in lower precision, is that it comes + with smaller numerical range. During backproping, small gradients might + underflow in the reduced numerical range, causing a model to converge at + suboptimal level. + + To prevent underflow, this optimizer multiplies the loss by a factor before + backprop starts. Consequently, the gradients are linearly scaled up by the + same factor, thus not falling into the underflow zone. After that, to perserve + the correctness of backprop, the gradients are down-scaled by the same factor, + casted to the (higher) variable precision, then applied on the variables. See [Nvidia's manual on mixed precision training]( https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index 334e70318dd88185cecd93ebeb2587861b7999b9..7cfdf0f607033479f03827ca20f4ad609e51cdfe 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -97,18 +97,19 @@ tf_gen_op_wrapper_py( deps = [":nccl_ops_op_lib"], ) +# Test only nccl ops lib without dso to test behavior when NCCL lib is not +# installed. See nccl_dependency_test for more details. +# +# Users should use the public nccl_py lib that also adds the dso. tf_custom_op_py_library( - name = "nccl_py", + name = "nccl_ops_lib_without_dso", srcs = [ "__init__.py", "python/ops/nccl_ops.py", ], - dso = [":python/ops/_nccl_ops.so"], kernels = if_cuda([":nccl_kernels"]) + [ ":nccl_ops_op_lib", ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], deps = [ ":nccl_ops", "//tensorflow/contrib/util:util_py", @@ -120,6 +121,15 @@ tf_custom_op_py_library( ], ) +tf_custom_op_py_library( + name = "nccl_py", + dso = [":python/ops/_nccl_ops.so"], + visibility = ["//visibility:public"], + deps = [ + ":nccl_ops_lib_without_dso", + ], +) + cuda_py_test( name = "nccl_ops_test", size = "small", @@ -141,3 +151,25 @@ cuda_py_test( "notap", ], ) + +cuda_py_test( + name = "nccl_dependency_test", + size = "small", + srcs = ["python/ops/nccl_dependency_test.py"], + additional_deps = [ + ":nccl_ops_lib_without_dso", + "//tensorflow/python:constant_op", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + ], + # Disable this test internally as static linking is used internally and only + # run for OSS to verify that NCCL is an optional dynamic dependency. + tags = [ + "manual", + "noguitar", + "notap", + ], +) diff --git a/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py b/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c766080dbee7c9a6f4383ef6fa8cade7bba158af --- /dev/null +++ b/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py @@ -0,0 +1,59 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dependency test for nccl to test behavior when NCCL is not installed.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import nccl +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect + + +class NcclDependencyTest(test.TestCase): + """Verifies that importing nccl ops lib does not fail even if NCCL is not + installed but nccl ops throws an exception on use if NCCL is not installed. + """ + + def test_nccl_ops(self): + """Tests behavior of nccl ops when NCCL is not installed.""" + + public_methods = [ + m[0] + for m in tf_inspect.getmembers(nccl, tf_inspect.isfunction) + if not m[0].startswith('_') + ] + for method_name in public_methods: + with ops.device('/device:CPU:0'): + tensor = constant_op.constant(1) + + if method_name == 'broadcast': + arg = tensor + else: + arg = [tensor] + + nccl_op = getattr(nccl, method_name) + with ops.device('/device:CPU:0'): + with self.assertRaisesRegexp(errors_impl.NotFoundError, + r'cannot open shared object file'): + nccl_op(arg) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 794372a1f4b0dcc41bcf0da611f5bc2ec9301973..029b01412d96ca03d4ecf7bf4d7d9872864e3ddc 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -26,8 +26,10 @@ from tensorflow.python.framework import device from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader -_nccl_ops_so = loader.load_op_library( - resource_loader.get_path_to_datafile('_nccl_ops.so')) + +_nccl_ops_so = None +_module_lock = threading.Lock() +_shared_name_counter = 0 def all_sum(tensors): @@ -180,7 +182,7 @@ def broadcast(tensor): A tensor with the value of `src_tensor`, which can be used as input to ops on other GPU devices. """ - _check_graph_mode() + _validate_and_load_nccl_so() _check_device(tensor) with ops.device(tensor.device): @@ -212,7 +214,7 @@ def _apply_all_reduce(reduction, tensors): """Helper function for all_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to all reduce operations') - _check_graph_mode() + _validate_and_load_nccl_so() shared_name = _get_shared_name() res = [] @@ -234,7 +236,7 @@ def _apply_reduce(reduction, tensors): """Helper function for reduce_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to reduce operations') - _check_graph_mode() + _validate_and_load_nccl_so() for t in tensors: _check_device(t) @@ -246,14 +248,10 @@ def _apply_reduce(reduction, tensors): return result -_lock = threading.Lock() -_shared_name_counter = 0 - - def _get_shared_name(): global _shared_name_counter - with _lock: + with _module_lock: val = _shared_name_counter _shared_name_counter += 1 return 'c%s' % val @@ -266,6 +264,25 @@ def _check_device(tensor, expected=None): raise ValueError('Expected device %s, got %s' % (expected, tensor.device)) -def _check_graph_mode(): +def _maybe_load_nccl_ops_so(): + """Loads nccl ops so if it hasn't been loaded already.""" + + with _module_lock: + global _nccl_ops_so + if not _nccl_ops_so: + _nccl_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile('_nccl_ops.so')) + + +def _validate_and_load_nccl_so(): + """Validates calling context and loads nccl ops so file. + + Raises: + ValueError: Ops are not supported. + errors_impl.NotFoundError: nccl library is not installed. + """ + if context.executing_eagerly(): raise ValueError('Nccl ops are not supported in eager mode') + + _maybe_load_nccl_ops_so() diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 13aa1d7e7a11877373a848c1ba865aa418790cd0..bbdf962d0480e52045d31f65b3d137ed3f11f2f1 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -19,6 +19,7 @@ py_library( "python/training/drop_stale_gradient_optimizer.py", "python/training/elastic_average_optimizer.py", "python/training/external_optimizer.py", + "python/training/ggt.py", "python/training/lazy_adam_optimizer.py", "python/training/model_average_optimizer.py", "python/training/moving_average_optimizer.py", @@ -28,15 +29,19 @@ py_library( "python/training/reg_adagrad_optimizer.py", "python/training/sign_decay.py", "python/training/variable_clipping_optimizer.py", + "python/training/weight_decay_optimizers.py", ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/optimizer_v2:optimizer_v2_py", "//tensorflow/python:array_ops", "//tensorflow/python:clip_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:state_ops", @@ -194,6 +199,25 @@ py_test( ], ) +py_test( + name = "weight_decay_optimizers_test", + srcs = ["python/training/weight_decay_optimizers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "drop_stale_gradient_optimizer_test", srcs = ["python/training/drop_stale_gradient_optimizer_test.py"], @@ -302,3 +326,21 @@ py_test( "//third_party/py/numpy", ], ) + +py_test( + name = "ggt_test", + srcs = ["python/training/ggt_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 4c13c8e247185213b798eb733ddcf65a07a8f64d..157ed6a278bb699724d3854426d780a3a58823db 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -27,10 +27,12 @@ from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * +from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * from tensorflow.contrib.opt.python.training.model_average_optimizer import * +from tensorflow.contrib.opt.python.training.ggt import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -46,6 +48,10 @@ _allowed_symbols = [ 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', + 'MomentumWOptimizer', + 'AdamWOptimizer', + 'DecoupledWeightDecayExtension', + 'extend_with_decoupled_weight_decay', 'ScipyOptimizerInterface', 'VariableClippingOptimizer', 'MultitaskOptimizerWrapper', @@ -53,7 +59,8 @@ _allowed_symbols = [ 'ElasticAverageOptimizer', 'ElasticAverageCustomGetter', 'ModelAverageOptimizer', - 'ModelAverageCustomGetter' + 'ModelAverageCustomGetter', + 'GGTOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py new file mode 100644 index 0000000000000000000000000000000000000000..928c453517f825ed2d305ec498d07ac29c065f1a --- /dev/null +++ b/tensorflow/contrib/opt/python/training/ggt.py @@ -0,0 +1,312 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GGT for Tensorflow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import numpy as np +from tensorflow.contrib.optimizer_v2 import optimizer_v2 +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops + + +class GGTOptimizer(optimizer_v2.OptimizerV2): + """Optimizer that implements the GGT algorithm. + + GGT has an advantage over sgd and adam on large models with poor conditioning, + for example language models and CNNs, + see [ABCHSZZ 2018]([pdf](https://arxiv.org/pdf/1806.02958.pdf)). + """ + + def __init__(self, + learning_rate=0.001, + beta1=0.9, + use_locking=False, + name="GGT", + window=10, + eps=1e-4, + svd_eps=1e-6, + sigma_eps=1e-2): + """Construct a new GGT optimizer. + + Initialization: + + ``` + t <- 0 (Initialize timestep) + grad_buffer <- 0 (Initialize buffer for keeping past gradients) + flat_grad <- 0 (Initialize flattened gradient that contains gradients of all + variables) + m_0 <- 0 (Initialize 1st moment vector) + ``` + + Suppose all variables and their gradients are concatenated into vectors + `flat_vars` and `flat_grad`. The update rule for `flat_vars` + uses an optimization described at the beginning of section 2 of the paper: + + ``` + t <- t + 1 + + m_t <- beta1 * m_{t-1} + (1 - beta1) * flat_grad + grad_buffer[(t-1) % window, :] <- m_t + + M <- grad_buffer^T / sqrt(min(t, window)) + U, sigma, _ <- SVD(M^TM + I * svd_eps) + + sigma_sqrt_inv <- (sqrt(sigma) + sigma_eps)^(-3) + sigma_sqrt_min <- min(sqrt(sigma)) + + if sigma_sqrt_min > eps: + new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t + + (m_t - M U diag(1/sigma) U^T M^T m_t) / sigma_sqrt_min + else: + new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t + + flat_vars <- flat_vars - learning_rate * new_step + ``` + + GGT provides the power of full-matrix adaptive regularization at a cost not + much larger than SGD. As a result it is suited for large models where the + gradient covariance matrix has a poor condition number that slows down first + order methods. + GGT uses the preconditioner from full-matrix AdaGrad, with gradient history + attenuated exponentially as in Adam, and truncated to a window parameter. + It has provable guarantees even for non-convex optimization that is never + significantly worse than SGD and in some cases better. + + Args: + learning_rate: A float hyperparameter. The learning rate. + beta1: A float hyperparameter. The exponential decay rate for the 1st + moment estimates. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "GGT". + window: An integer hyperparameter. The number of first moments to keep in + computing the adaptive preconditioner. + eps: A float hyperparameter. Used to truncate small eigenvalues of the + gradient covariance matrix. + svd_eps: A float hyperparameter. Used to stabilize SVD. + sigma_eps: A float hyperparameter. Used to regularize matrix inversion. + """ + super(GGTOptimizer, self).__init__(use_locking, name) + self._set_hyper("lr", learning_rate) + self._set_hyper("beta1", beta1) + self._set_hyper("window", window) + self._set_hyper("eps", eps) + self._set_hyper("svd_eps", svd_eps) + self._set_hyper("sigma_eps", sigma_eps) + + self.index_dict = {} + self.shape_dict = {} + + def _create_vars(self, var_list, state): + # Construct ordered dictionary for variable dimensions, sorted by name. + shape_dict = {} + for v in var_list: + shape_dict[v.name] = np.prod(v.get_shape()).value + self.shape_dict = collections.OrderedDict( + sorted(shape_dict.items(), key=lambda t: t[0])) + + # Assign each variable its location in flat_grad. The locations are based on + # the order of sorted names. + idx = 0 + for v_name, v_dim in self.shape_dict.items(): + self.index_dict[v_name] = idx + idx += v_dim + + state.create_non_slot( + initial_value=math_ops.cast(0., dtype=var_list[0].dtype.base_dtype), + name="global_step") + + # Buffer for keeping past gradients. + window = state.get_hyper("window") + grad_buffer_init = array_ops.zeros( + [window, idx], dtype=var_list[0].dtype.base_dtype) + state.create_non_slot(initial_value=grad_buffer_init, name="grad_buffer") + + state.create_non_slot( + initial_value=array_ops.zeros( + (idx,), dtype=var_list[0].dtype.base_dtype), + name="moment1") + + # Flattened gradient that contains gradients for all variables in the model. + state.create_non_slot( + initial_value=array_ops.zeros( + (idx,), dtype=var_list[0].dtype.base_dtype), + name="flat_grad") + + def _get_global_step(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("global_step") + + def _get_moment1(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("moment1") + + def _get_grad_buffer(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("grad_buffer") + + def _get_flat_grad(self, state=None): + if state is None: + state = self._get_per_graph_state() + return state.get_non_slot("flat_grad") + + def _apply_sparse(self, grad, var): + raise NotImplementedError("Sparse gradient updates are not supported.") + + def _prepare(self, state): + self._variables = [] + + def _apply_dense(self, grad, var, state): + self._variables.append(var) + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + + # Update flat_gradient at the index associated with the variable. + flat_grad = self._get_flat_grad(state) + new_flat_grad = array_ops.reshape(grad, [-1]) + flat_grad_updated = state_ops.scatter_update( + flat_grad, math_ops.range(start_index, end_index), new_flat_grad) + + return flat_grad_updated + + def _resource_apply_dense(self, grad, var, state): + self._variables.append(var) + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + + # Update flat_gradient at the index associated with the variable. + flat_grad = self._get_flat_grad(state) + new_flat_grad = array_ops.reshape(grad, [-1]) + flat_grad_updated = state_ops.scatter_update( + flat_grad, math_ops.range(start_index, end_index), new_flat_grad) + + return flat_grad_updated + + def _finish(self, state): + var_dtype = self._variables[0].dtype.base_dtype + # Update global step. + global_step = self._get_global_step(state) + update_global_step = state_ops.assign_add(global_step, 1.) + + # Update the first moment estimate. + beta1 = state.get_hyper("beta1", dtype=var_dtype) + moment1 = self._get_moment1(state) + flat_grad = self._get_flat_grad(state) + # moment1_t := beta1 * moment1_{t-1} + (1 - beta1) * flat_grad_t + update_moment1 = moment1.assign(beta1 * moment1 + (1. - beta1) * flat_grad) + + # Update the gradient buffer. + window = state.get_hyper("window") + grad_buffer = self._get_grad_buffer(state) + next_grad_index = math_ops.floormod( + math_ops.to_int32(update_global_step - 1.), window) + # grad_buffer[(t-1) % window] := moment1_t + update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index, + update_moment1) + + # Compute the update step. + eps = state.get_hyper("eps", dtype=var_dtype) + svd_eps = state.get_hyper("svd_eps", dtype=var_dtype) + sigma_eps = state.get_hyper("sigma_eps", dtype=var_dtype) + lr = state.get_hyper("lr", dtype=var_dtype) + denom = math_ops.sqrt( + math_ops.minimum( + ops.convert_to_tensor(update_global_step), + ops.convert_to_tensor(math_ops.cast(window, dtype=var_dtype)))) + moment1_2d = array_ops.expand_dims(update_moment1, -1) + + # m = grad_buffer^T / sqrt(min(t, window)) + # m has shape [model dimension, window], where model dimension is the sum + # of the dimensions of the flattened variables. + m = array_ops.transpose(math_ops.divide(update_grad_buffer, denom)) + + # sigma, u, _ = SVD(m^Tm + I * svd_eps) + mm = math_ops.matmul(m, m, transpose_a=True) + damping = math_ops.cast(linalg_ops.eye(window), dtype=var_dtype) * svd_eps + sigma, u, _ = linalg_ops.svd(mm + damping) + sigma_sqrt = math_ops.sqrt(sigma) + sigma_sqrt_min = math_ops.reduce_min(sigma_sqrt) + + # sigma_sqrt_inv = 1 / (\sqrt{sigma} + sigma_eps) ^ 3 + # We add sigma_eps to alleviate numerical instability. + # Note that (m^Tm)^(-3/2) = u diag(sigma_sqrt_inv) u^T. + sigma_sqrt_inv = math_ops.divide( + math_ops.cast(1.0, dtype=var_dtype), + math_ops.pow(sigma_sqrt + sigma_eps, 3)) + + # In full matrix AdaGrad, the update step computes (mm^T)^(-1/2)g, where the + # inversion of a model dimension by model dimension matrix is needed. To + # speed up this computation we calculate the following instead: + # m(m^Tm)^(-3/2)m^T moment1 = m u diag(sigma_sqrt_inv) u^T m^T moment1. + new_step = array_ops.expand_dims( + array_ops.zeros(flat_grad.get_shape(), dtype=var_dtype), -1) + head = math_ops.matmul( + m, + math_ops.matmul( + u, + math_ops.matmul( + array_ops.diag(sigma_sqrt_inv), + math_ops.matmul( + u, + math_ops.matmul(m, moment1_2d, transpose_a=True), + transpose_a=True)))) + + # When inverting (mm^t)^(1/2), we also add epsilon * I regularization for + # degenerate cases. We expand ((mm^t)^(1/2) + epsilon * I)^(-1) using + # Woodbury's identity. + # For full derivation please see paper at + # https://arxiv.org/pdf/1806.02958.pdf + tail = moment1_2d - math_ops.matmul( + m, + math_ops.matmul( + u, + math_ops.matmul( + array_ops.diag( + math_ops.divide(math_ops.cast(1.0, dtype=var_dtype), + sigma)), + math_ops.matmul( + u, + math_ops.matmul(m, moment1_2d, transpose_a=True), + transpose_a=True)))) + scaled_tail = math_ops.divide(tail, sigma_sqrt_min) + + update_new_step = control_flow_ops.cond( + sigma_sqrt_min > eps, lambda: math_ops.add(head, scaled_tail), + lambda: math_ops.add(new_step, head)) + + # Update each variable. + update_step = [] + for var in self._variables: + dim = self.shape_dict[var.name] + start_index = self.index_dict[var.name] + end_index = start_index + dim + var_update_correct_shape = array_ops.reshape( + update_new_step[start_index:end_index], var.get_shape()) + var_updated = state_ops.assign_sub(var, lr * var_update_correct_shape) + update_step.append(var_updated) + + return control_flow_ops.group(update_step) diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py new file mode 100644 index 0000000000000000000000000000000000000000..42162960b049cd90c663989fb4fc9d7f179a84ff --- /dev/null +++ b/tensorflow/contrib/opt/python/training/ggt_test.py @@ -0,0 +1,183 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for GGTOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.opt.python.training.ggt import GGTOptimizer +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def ggt_update_numpy(param, + g_t, + lr, + grad_buffer, + m, + window, + t, + beta1=0.9, + eps=1e-4, + svd_eps=1e-6, + sigma_eps=1e-2): + """Tests the correctness of one step of GGT.""" + m_t = m * beta1 + (1 - beta1) * g_t + grad_buffer[((t - 1) % window), :] = m_t + m_matrix = np.transpose(grad_buffer / np.sqrt(np.minimum(t, window))) + mm = np.dot(np.transpose(m_matrix), m_matrix) + damping = np.eye(window) * svd_eps + u, sigma, _ = np.linalg.svd(mm + damping) + + sigma_sqrt_inv = np.power(np.sqrt(sigma) + sigma_eps, -3) + new_step = np.linalg.multi_dot([ + m_matrix, u, + np.diag(sigma_sqrt_inv), + np.transpose(u), + np.transpose(m_matrix), m_t + ]) + + sigma_sqrt_min = np.sqrt(sigma).min() + + if sigma_sqrt_min > eps: + new_step += (m_t - np.linalg.multi_dot([ + m_matrix, u, + np.diag(1.0 / sigma), + np.transpose(u), + np.transpose(m_matrix), m_t + ])) * (1.0 / sigma_sqrt_min) + + param_t = param - lr * new_step + return param_t, m_t, grad_buffer + + +class GGTOptimizerTest(test.TestCase): + + def doTestBasic(self, use_resource=False): + # SVD does not support float16 + for i, dtype in enumerate([dtypes.float32, dtypes.float64]): + with self.test_session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0 = 0.0 + window = 3 + grad_buffer = np.zeros((window, 4), dtype=dtype.as_numpy_dtype) + lr = 0.001 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np, name="var0") + var1 = variables.Variable(var1_np, name="var1") + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = GGTOptimizer(learning_rate=lr, window=window) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + opt_variables = opt.variables() + + m_t = opt._get_moment1() + grad_buffer_t = opt._get_grad_buffer() + g_t = opt._get_flat_grad() + self.assertTrue(m_t is not None) + self.assertTrue(grad_buffer_t is not None) + self.assertTrue(g_t is not None) + self.assertIn(m_t, opt_variables) + self.assertIn(grad_buffer_t, opt_variables) + self.assertIn(g_t, opt_variables) + + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + + if not context.executing_eagerly(): + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + m_t = opt._get_moment1() + grad_buffer_t = opt._get_grad_buffer() + g_t = opt._get_flat_grad() + + # Run 3 steps of GGT + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + if t == 1: + self.assertAllCloseAccordingToType( + np.array([0.01, 0.01, 0.001, 0.001]), self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, 0.001], [0., 0., 0., 0.], + [0., 0., 0., 0.]]), self.evaluate(grad_buffer_t)) + elif t == 2: + self.assertAllCloseAccordingToType( + np.array([0.019, 0.019, 0.0019, 0.0019]), self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, 0.001], + [0.019, 0.019, 0.0019, 0.0019], [0., 0., 0., 0.]]), + self.evaluate(grad_buffer_t)) + else: + self.assertAllCloseAccordingToType( + np.array([0.0271, 0.0271, 0.00271, 0.00271]), + self.evaluate(m_t)) + self.assertAllCloseAccordingToType( + np.array([[0.01, 0.01, 0.001, + 0.001], [0.019, 0.019, 0.0019, 0.0019], + [0.0271, 0.0271, 0.00271, 0.00271]]), + self.evaluate(grad_buffer_t)) + + self.assertAllCloseAccordingToType([0.1, 0.1, 0.01, 0.01], + self.evaluate(g_t)) + + var_np = np.append(var0_np, var1_np) + grads_np = np.append(grads0_np, grads1_np) + var_np, m0, grad_buffer = ggt_update_numpy(var_np, grads_np, lr, + grad_buffer, m0, window, t) + + var0_np = var_np[:2] + var1_np = var_np[2:] + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + + def testBasic(self): + with self.test_session(): + self.doTestBasic(use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTestBasic(use_resource=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa40aeb45d4ec15140bdfc5ebd824e8aa08d8d9 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -0,0 +1,326 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Base class to make optimizers weight decay ready.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.training import optimizer +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import adam +from tensorflow.python.training import momentum as momentum_opt +from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import resource_variable_ops + + +class DecoupledWeightDecayExtension(object): + """This class allows to extend optimizers with decoupled weight decay. + + It implements the decoupled weight decay described by Loshchilov & Hutter + (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is + decoupled from the optimization steps w.r.t. to the loss function. + For SGD variants, this simplifies hyperparameter search since it decouples + the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + This class alone is not an optimizer but rather extends existing + optimizers with decoupled weight decay. We explicitly define the two examples + used in the above paper (SGDW and AdamW), but in general this can extend + any OptimizerX by using + `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`. + In order for it to work, it must be the first class the Optimizer with + weight decay inherits from, e.g. + + ```python + class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + def __init__(self, weight_decay, *args, **kwargs): + super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs). + ``` + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + """ + + def __init__(self, weight_decay, **kwargs): + """Construct the extension class that adds weight decay to an optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value, the factor by which + a variable is decayed in the update step. + decay_var_list: Optional list or tuple or set of `Variable` objects to + decay. + """ + self._decay_var_list = None # is set in minimize or apply_gradients + self._weight_decay = weight_decay + # The tensors are initialized in call to _prepare + self._weight_decay_tensor = None + super(DecoupledWeightDecayExtension, self).__init__(**kwargs) + + def minimize(self, loss, global_step=None, var_list=None, + gate_gradients=optimizer.Optimizer.GATE_OP, + aggregation_method=None, colocate_gradients_with_ops=False, + name=None, grad_loss=None, decay_var_list=None): + """Add operations to minimize `loss` by updating `var_list` with decay. + + This function is the same as Optimizer.minimize except that it allows to + specify the variables that should be decayed using decay_var_list. + If decay_var_list is None, all variables in var_list are decayed. + + For more information see the documentation of Optimizer.minimize. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).minimize( + loss, global_step=global_step, var_list=var_list, + gate_gradients=gate_gradients, aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, name=name, + grad_loss=grad_loss) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None, + decay_var_list=None): + """Apply gradients to variables and decay the variables. + + This function is the same as Optimizer.apply_gradients except that it + allows to specify the variables that should be decayed using + decay_var_list. If decay_var_list is None, all variables in var_list + are decayed. + + For more information see the documentation of Optimizer.apply_gradients. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).apply_gradients( + grads_and_vars, global_step=global_step, name=name) + + def _prepare(self): + weight_decay = self._weight_decay + if callable(weight_decay): + weight_decay = weight_decay() + self._weight_decay_tensor = ops.convert_to_tensor( + weight_decay, name="weight_decay") + # Call the optimizers _prepare function. + super(DecoupledWeightDecayExtension, self)._prepare() + + def _decay_weights_op(self, var): + if not self._decay_var_list or var in self._decay_var_list: + return var.assign_sub(self._weight_decay * var, self._use_locking) + return control_flow_ops.no_op() + + def _decay_weights_sparse_op(self, var, indices, scatter_add): + if not self._decay_var_list or var in self._decay_var_list: + return scatter_add(var, indices, -self._weight_decay * var, + self._use_locking) + return control_flow_ops.no_op() + + # Here, we overwrite the apply functions that the base optimizer calls. + # super().apply_x resolves to the apply_x function of the BaseOptimizer. + def _apply_dense(self, grad, var): + with ops.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var) + + def _resource_apply_dense(self, grad, var): + with ops.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, self)._resource_apply_dense( + grad, var) + + def _apply_sparse(self, grad, var): + scatter_add = state_ops.scatter_add + decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add) + with ops.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, self)._apply_sparse( + grad, var) + + def _resource_scatter_add(self, x, i, v, _=None): + # last argument allows for one overflow argument, to have the same function + # signature as state_ops.scatter_add + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + scatter_add = self._resource_scatter_add + decay_op = self._decay_weights_sparse_op(var, indices, scatter_add) + with ops.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse( + grad, var, indices) + + +def extend_with_decoupled_weight_decay(base_optimizer): + """Factory function returning an optimizer class with decoupled weight decay. + + Returns an optimizer class. An instance of the returned class computes the + update step of `base_optimizer` and additionally decays the weights. + E.g., the class returned by + `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to + `tf.contrib.opt.AdamWOptimizer`. + + The API of the new optimizer class slightly differs from the API of the + base optimizer: + - The first argument to the constructor is the weight decay rate. + - `minimize` and `apply_gradients` accept the optional keyword argument + `decay_var_list`, which specifies the variables that should be decayed. + If `None`, all variables that are optimized are decayed. + + Usage example: + ```python + # MyAdamW is a new class + MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer) + # Create a MyAdamW object + optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) + sess.run(optimizer.minimize(loss, decay_variables=[var1, var2])) + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + ``` + + Args: + base_optimizer: An optimizer class that inherits from tf.train.Optimizer. + + Returns: + A new optimizer class that inherits from DecoupledWeightDecayExtension + and base_optimizer. + """ + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, + base_optimizer): + """Base_optimizer with decoupled weight decay. + + This class computes the update step of `base_optimizer` and + additionally decays the variable with the weight decay being decoupled from + the optimization steps w.r.t. to the loss function, as described by + Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf). + For SGD variants, this simplifies hyperparameter search since + it decouples the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. + """ + + def __init__(self, weight_decay, *args, **kwargs): + # super delegation is necessary here + # pylint: disable=useless-super-delegation + super(OptimizerWithDecoupledWeightDecay, self).__init__( + weight_decay, *args, **kwargs) + # pylint: enable=useless-super-delegation + + return OptimizerWithDecoupledWeightDecay + + +@tf_export("contrib.opt.MomentumWOptimizer") +class MomentumWOptimizer(DecoupledWeightDecayExtension, + momentum_opt.MomentumOptimizer): + """Optimizer that implements the Momentum algorithm with weight_decay. + + This is an implementation of the SGDW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + It computes the update step of `train.MomentumOptimizer` and additionally + decays the variable. Note that this is different from adding + L2 regularization on the variables to the loss. Decoupling the weight decay + from other hyperparameters (in particular the learning rate) simplifies + hyperparameter search. + + For further information see the documentation of the Momentum Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.train.MomentumOptimizer, + weight_decay=weight_decay) + ``` + """ + + def __init__(self, weight_decay, learning_rate, momentum, + use_locking=False, name="MomentumW", use_nesterov=False): + """Construct a new MomentumW optimizer. + + For further information see the documentation of the Momentum Optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value. The weight decay. + learning_rate: A `Tensor` or a floating point value. The learning rate. + momentum: A `Tensor` or a floating point value. The momentum. + use_locking: If `True` use locks for update operations. + name: Optional name prefix for the operations created when applying + gradients. Defaults to "Momentum". + use_nesterov: If `True` use Nesterov Momentum. + See [Sutskever et al., 2013]( + http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). + This implementation always computes gradients at the value of the + variable(s) passed to the optimizer. Using Nesterov Momentum makes the + variable(s) track the values called `theta_t + mu*v_t` in the paper. + + @compatibility(eager) + When eager execution is enabled, learning_rate, weight_decay and momentum + can each be a callable that takes no arguments and returns the actual value + to use. This can be useful for changing these values across different + invocations of optimizer functions. + @end_compatibility + """ + super(MomentumWOptimizer, self).__init__( + weight_decay, learning_rate=learning_rate, momentum=momentum, + use_locking=use_locking, name=name, use_nesterov=use_nesterov) + + +@tf_export("contrib.opt.AdamWOptimizer") +class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + """Optimizer that implements the Adam algorithm with weight decay. + + This is an implementation of the AdamW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + + It computes the update step of `train.AdamOptimizer` and additionally decays + the variable. Note that this is different from adding L2 regularization on + the variables to the loss: it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + For further information see the documentation of the Adam Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.train.AdamOptimizer, weight_decay=weight_decay) + ``` + """ + + def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999, + epsilon=1e-8, use_locking=False, name="AdamW"): + """Construct a new AdamW optimizer. + + For further information see the documentation of the Adam Optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value. The weight decay. + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + """ + super(AdamWOptimizer, self).__init__( + weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2, + epsilon=epsilon, use_locking=use_locking, name=name) diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..74d1cdbbdac8724518937d141a976abf9fec6ce3 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optimizers with weight decay.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adam +from tensorflow.contrib.opt.python.training import weight_decay_optimizers + +WEIGHT_DECAY = 0.01 + + +def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9, + beta2=0.999, epsilon=1e-8): + lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) - + (param * WEIGHT_DECAY)) + return param_t, m_t, v_t + + +def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_): + # v, t are not needed for momentum optimizer + m = momentum * m + g_t + param_t = param - lr * m - param * WEIGHT_DECAY + return param_t, m, None + + +class WeightDecayOptimizerTest(test.TestCase): + + def doTest(self, optimizer, update_fn, optimizer_name, slot_name, + use_resource=False, do_sparse=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.test_session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + if do_sparse: + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices(constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), + constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices(constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), + constant_op.constant([2])) + else: + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = optimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of the optimizer + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0) + var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/%s:0" % (i, optimizer_name), + opt.get_slot(var=var0, name=slot_name).name) + + +class AdamWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY) + + def testSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True) + + +class MomentumWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9) + + def testSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True) + + +class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + AdamW = weight_decay_optimizers.extend_with_decoupled_weight_decay( + adam.AdamOptimizer) + return AdamW(WEIGHT_DECAY) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", + use_resource=True) + + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py index d538ad0fb02699ed8514f512208914f629a47436..631d4f44dfb646541244bfe1d15136dd29f02703 100644 --- a/tensorflow/contrib/optimizer_v2/adam.py +++ b/tensorflow/contrib/optimizer_v2/adam.py @@ -103,9 +103,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2): def _create_vars(self, var_list, state): # Non-slot variables end up on the same device(s). - state.create_non_slot(initial_value=state.get_hyper("beta1"), + state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"), name="beta1_power") - state.create_non_slot(initial_value=state.get_hyper("beta2"), + state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"), name="beta2_power") # Create slots for the first and second moments. diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index f537318b32986c941b6c41eb363929e906027dd7..c6f3bd6ee18fa353944e2fc303573894933f5b27 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -162,12 +162,12 @@ def _get_processor(v): def _var_key_v2(var): """Key for representing a primary variable, for looking up slots.""" # pylint: disable=protected-access - if hasattr(var, "_mirrored_container"): - mirrored_container = var._mirrored_container() - assert mirrored_container is not None + if hasattr(var, "_distributed_container"): + distributed_container = var._distributed_container() + assert distributed_container is not None if context.executing_eagerly(): - return mirrored_container._unique_id - return mirrored_container._shared_name + return distributed_container._unique_id + return distributed_container._shared_name if context.executing_eagerly(): return var._unique_id return var.op.name @@ -211,8 +211,9 @@ class _OptimizerV2State(object): # This dict starts with a single item with key "None" with the hyper # parameter value converted to a Tensor. Other items have dtype keys # with that Tensor cast to that dtype. - self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)} - for name, (dynamic, value) in hyper.items() if not dynamic} + with ops.init_scope(): + self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)} + for name, (dynamic, value) in hyper.items() if not dynamic} self._slots = {} self._non_slot_dict = {} # Extra state to help Optimizers implement Checkpointable. Holds information diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD index 976b312e8345a801ad07f622b6117b88af2cf603..f2171efc959362c1e4392fefbd5842f0883571d7 100644 --- a/tensorflow/contrib/periodic_resample/BUILD +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -97,6 +97,8 @@ tf_cc_test( ], deps = [ ":all_ops", + "//tensorflow/core:framework", + "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", ], diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc index 55edf76fcd3eed461e1465b569e1c2e9e2facbc0..43b7c1799ffb2e27f9d15bc6011d49334867b6ec 100644 --- a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc +++ b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/recurrent/BUILD b/tensorflow/contrib/recurrent/BUILD index b3cb04ce26d96333f516f1298c8d5c331964f05b..f9827f766da022b184b3348fc24b1570bac8678f 100644 --- a/tensorflow/contrib/recurrent/BUILD +++ b/tensorflow/contrib/recurrent/BUILD @@ -102,5 +102,8 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], - tags = ["nopip"], + tags = [ + "nopip", + "optonly", + ], ) diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py index 03d6da7765ba5249a9fb22f56a469cf07c310479..f10d78259a3be3a3a6f7f78c196ab107f18a53aa 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py @@ -147,7 +147,7 @@ class SpectralOpsTest(test.TestCase): inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8, fft_length=16, frame_step=8) expected_length = (stft.shape[0] - 1) * 8 + 8 - self.assertAllEqual([None], inverse_stft.shape.as_list()) + self.assertAllEqual([256], inverse_stft.shape.as_list()) self.assertAllEqual([expected_length], inverse_stft.eval().shape) def test_stft_and_inverse_stft(self): diff --git a/tensorflow/contrib/signal/python/kernel_tests/test_util.py b/tensorflow/contrib/signal/python/kernel_tests/test_util.py index 9a3603b6a97ef7c3a4b940b83281ebceda93c9db..7d6289532addfd4b4b867bf64d9113253bd1c76d 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/test_util.py +++ b/tensorflow/contrib/signal/python/kernel_tests/test_util.py @@ -39,6 +39,7 @@ def grappler_optimize(graph, fetches=None, rewriter_config=None): """ if rewriter_config is None: rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.min_graph_nodes = -1 if fetches is not None: for fetch in fetches: graph.add_to_collection('train_op', fetch) diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py index 9305c6a11c4ec898c82553773e8e7277a54ab82e..85918bf8506623cf5e0c9106ae9ed80e233f5a7d 100644 --- a/tensorflow/contrib/solvers/python/ops/linear_equations.py +++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import linalg_ops def conjugate_gradient(operator, diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index a5d8b061b6b26f9d05be40a1162481ae219b0e9c..adda0b758b172f5e80c165e4b28dbdbecef2ba16 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -49,7 +49,6 @@ tf_cuda_cc_test( tf_custom_op_library( name = "python/ops/_trt_engine_op.so", srcs = [ - "ops/trt_calib_op.cc", "ops/trt_engine_op.cc", ], deps = [ @@ -76,11 +75,9 @@ tf_cuda_library( cc_library( name = "trt_engine_op_kernel", srcs = [ - "kernels/trt_calib_op.cc", "kernels/trt_engine_op.cc", ], hdrs = [ - "kernels/trt_calib_op.h", "kernels/trt_engine_op.h", ], copts = tf_copts(), @@ -89,20 +86,22 @@ cc_library( ":trt_logging", ":trt_plugins", ":trt_resources", + ":trt_conversion", + ":utils", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:stream_executor_headers_lib", + "//tensorflow/core/grappler/costs:graph_properties", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]) + tf_custom_op_library_additional_deps(), - # TODO(laigd) + # TODO(laigd): fix this by merging header file in cc file. alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs ) tf_gen_op_libs( op_lib_names = [ "trt_engine_op", - "trt_calib_op", ], ) @@ -122,7 +121,6 @@ tf_gen_op_wrapper_py( name = "trt_engine_op", gen_locally = True, deps = [ - ":trt_calib_op_op_lib", ":trt_engine_op_op_lib", ":trt_logging", ":trt_shape_function", @@ -140,7 +138,6 @@ tf_custom_op_py_library( kernels = [ ":trt_engine_op_kernel", ":trt_engine_op_op_lib", - ":trt_calib_op_op_lib", ":trt_shape_function", ], srcs_version = "PY2AND3", @@ -191,7 +188,6 @@ tf_py_wrap_cc( deps = [ ":trt_conversion", ":trt_engine_op_kernel", - "//tensorflow/core:framework_lite", "//third_party/python_runtime:headers", ], ) @@ -211,6 +207,7 @@ tf_cuda_library( ], deps = [ ":trt_logging", + ":utils", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_lite", "//tensorflow/core:lib_proto_parsing", @@ -237,12 +234,12 @@ tf_cuda_library( ":trt_plugins", ":trt_logging", ":trt_resources", + ":utils", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", - "//tensorflow/core:framework", "//tensorflow/core:gpu_runtime", "//tensorflow/core:framework_lite", "//tensorflow/core:graph", @@ -343,3 +340,8 @@ py_test( "//tensorflow/python:framework_test_lib", ], ) + +cc_library( + name = "utils", + hdrs = ["convert/utils.h"], +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index da4dd5a14cd74591fc9df63cd5868044e4e369ec..1c4fd4a0ce1972b92d9d96347ce540d773e04422 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include #include #include #include @@ -24,10 +24,17 @@ limitations under the License. #include #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/process_state.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -39,17 +46,39 @@ limitations under the License. #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT +#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT +#include "tensorflow/core/util/device_name_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" #include "tensorrt/include/NvInfer.h" - namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +// Returns compiled TRT version information {Maj, Min, Patch} +std::vector GetLinkedTensorRTVersion() { + return {NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH}; +} + +// Returns loaded TRT library version {Maj, Min, Patch} +std::vector GetLoadedTensorRTVersion() { + int ver = getInferLibVersion(); + int ver_major = ver / 1000; + ver = ver - ver_major * 1000; + int ver_minor = ver / 100; + int ver_patch = ver - ver_minor * 100; + return {ver_major, ver_minor, ver_patch}; +} + namespace { bool IsTensorRTCandidate(const tensorflow::Node* node) { @@ -82,229 +111,6 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } -void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, - const std::set& subgraph_node_ids, - tensorflow::EdgeSet* incoming_edges) { - for (int node_id : subgraph_node_ids) { - const tensorflow::Node* node = graph.FindNodeId(node_id); - for (const tensorflow::Edge* edge : node->in_edges()) { - if (!subgraph_node_ids.count(edge->src()->id()) && - !edge->src()->IsSource() && !edge->IsControlEdge()) { - incoming_edges->insert(edge); - VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() - << " Y, "; - } else { - VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() - << " N, "; - } - } - } -} - -void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, - const std::set& subgraph_node_ids, - tensorflow::EdgeSet* outgoing_edges) { - for (int node_id : subgraph_node_ids) { - const tensorflow::Node* node = graph.FindNodeId(node_id); - for (const tensorflow::Edge* edge : node->out_edges()) { - if (!subgraph_node_ids.count(edge->dst()->id()) && - !edge->dst()->IsSink() && !edge->IsControlEdge()) { - VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() - << " Y, "; - outgoing_edges->insert(edge); - } else { - VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() - << " N, "; - } - } - } -} - -std::pair ParseTensorName(const string& name, - int default_idx = 0) { - string name_no_idx = name; - int idx = default_idx; - const size_t sep = name_no_idx.find_last_of(':'); - if (sep != string::npos) { - name_no_idx = name_no_idx.substr(0, sep); - idx = std::stoi(name.substr(sep + 1)); - } - return std::make_pair(name_no_idx, idx); -} - -std::unordered_map> BuildTensorNameMap( - const std::vector& tensor_names) { - std::unordered_map> result; - for (const string& tensor_name : tensor_names) { - string node_name; - int index; - std::tie(node_name, index) = ParseTensorName(tensor_name); - result[node_name].push_back(index); - } - return result; -} - -// TODO(sami): convert references to pointers -struct ConvertGraphParams { - ConvertGraphParams( - tensorflow::Graph& inp_graph, - const std::vector& output_node_names, - const std::set& subgraph_node_id_numbers, - size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& current_graph_properties, - std::unordered_map>* output_edges, - int engine_precision_mode, const string& device_name, - std::shared_ptr allocator, int cuda_gpu_id) - : graph(inp_graph), - output_names(output_node_names), - subgraph_node_ids(subgraph_node_id_numbers), - max_batch_size(max_supported_batch_size), - max_workspace_size_bytes(max_consumed_workspace_size_bytes), - graph_properties(current_graph_properties), - output_edge_map(output_edges), - precision_mode(engine_precision_mode), - device_name_(device_name), - allocator_(allocator), - cuda_gpu_id_(cuda_gpu_id) {} - tensorflow::Graph& graph; - const std::vector& output_names; - const std::set& subgraph_node_ids; - size_t max_batch_size; - size_t max_workspace_size_bytes; - const tensorflow::grappler::GraphProperties& graph_properties; - std::unordered_map>* output_edge_map; - int precision_mode; - string device_name_; - std::shared_ptr allocator_; - int cuda_gpu_id_; - std::vector> subgraph_inputs; - std::vector> subgraph_outputs; - tensorflow::EdgeSet subgraph_incoming_edges; - tensorflow::EdgeSet subgraph_outgoing_edges; -}; - -static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { - GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids, - &p->subgraph_incoming_edges); - - std::set> unique_tensors; - // Add only unique input source nodes. If output of an outside node is shared - // between multiple nodes inside the engine, only one edge should be created - for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) { - unique_tensors.insert({edge->src()->id(), edge->src_output()}); - } - p->subgraph_inputs.insert(p->subgraph_inputs.begin(), unique_tensors.begin(), - unique_tensors.end()); - GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids, - &p->subgraph_outgoing_edges); - unique_tensors.clear(); - // Similar to above, if multiple ouside nodes are sharing the output of an - // internal node only one output port should be created and shared between - // outputs - for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) { - unique_tensors.insert({edge->src()->id(), edge->src_output()}); - } - p->subgraph_outputs.reserve(unique_tensors.size()); - p->subgraph_outputs.insert(p->subgraph_outputs.begin(), - unique_tensors.begin(), unique_tensors.end()); - return tensorflow::Status::OK(); -} - -tensorflow::Status GetCalibNode(ConvertGraphParams* params) { - TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params)); - tensorflow::NodeDef trt_node_def; - SubGraphParams s(params->graph, params->subgraph_node_ids, - params->subgraph_inputs, params->subgraph_outputs, - params->max_batch_size, params->max_workspace_size_bytes, - params->graph_properties, params->output_edge_map, - &trt_node_def, params->precision_mode, params->device_name_, - params->allocator_, params->cuda_gpu_id_); - TF_RETURN_IF_ERROR(InjectCalibrationNode(s)); - tensorflow::Status status; - tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status); - - TF_RETURN_IF_ERROR(status); - - for (auto in_edge : - params->subgraph_incoming_edges) { // loop over incoming edges and - // attach them to calib node - auto src_output = in_edge->src_output(); - auto dst_node = in_edge->dst(); - auto dst_input = in_edge->dst_input(); - VLOG(1) << " update edge " << trt_node->name() << ":" << src_output - << " -> " << dst_node->name() << ":" << dst_input; - TF_RETURN_IF_ERROR( - params->graph.UpdateEdge(trt_node, src_output, dst_node, dst_input)); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { - TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params)); - tensorflow::NodeDef trt_node_def; - - SubGraphParams s(params->graph, params->subgraph_node_ids, - params->subgraph_inputs, params->subgraph_outputs, - params->max_batch_size, params->max_workspace_size_bytes, - params->graph_properties, params->output_edge_map, - &trt_node_def, params->precision_mode, params->device_name_, - params->allocator_, params->cuda_gpu_id_); - TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(s)); - tensorflow::Status status; - tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status); - - // AddNode does not wire edges. - // Re-map incoming edges to use the new TRT node instead of the orig subgraph - std::map, int> subgraph_edge_to_input_map; - for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) { - subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i}); - } - std::set> unique_tensors; - for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) { - std::pair old_src = {edge->src()->id(), edge->src_output()}; - if (unique_tensors.count(old_src)) continue; - unique_tensors.insert(old_src); - int new_src_output = subgraph_edge_to_input_map.at(old_src); - params->graph.AddEdge(edge->src(), edge->src_output(), trt_node, - new_src_output); - VLOG(1) << "Wire " << edge->src()->name() << ":" << edge->src_output() - << " -> " << trt_node->name() << ":" << new_src_output; - params->graph.RemoveEdge(edge); - } - if (VLOG_IS_ON(2)) { - VLOG(2) << "new edge count: " << trt_node->in_edges().size(); - for (const tensorflow::Edge* edge : trt_node->in_edges()) { - VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); - } - } - TF_RETURN_IF_ERROR(status); - - // Re-map outgoing edges to use the new TRT node instead of the orig subgraph - std::map, int> subgraph_edge_to_output_map; - for (size_t i = 0; i < params->subgraph_outputs.size(); ++i) { - subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i}); - } - TF_RETURN_IF_ERROR(status); - for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) { - std::pair old_src = {edge->src()->id(), edge->src_output()}; - int new_src_output = subgraph_edge_to_output_map.at(old_src); - TF_RETURN_IF_ERROR(params->graph.UpdateEdge( - trt_node, new_src_output, edge->dst(), edge->dst_input())); - VLOG(1) << "Wire " << trt_node->name() << ":" << new_src_output << " -> " - << edge->dst()->name() << ":" << edge->dst_input(); - } - // Remove the original subgraph - for (int node_id : params->subgraph_node_ids) { - tensorflow::Node* node = params->graph.FindNodeId(node_id); - // Don't remove the input placeholders - if (node->type_string() == "Placeholder") { - continue; - } - params->graph.RemoveNode(node); - } - return tensorflow::Status::OK(); -} - tensorflow::Status BuildNodeMap( const tensorflow::Graph& graph, std::unordered_map* node_map) { @@ -318,51 +124,77 @@ tensorflow::Status BuildNodeMap( } } // namespace + +// Function to get calibration from ResourceMgr and put them into nodedef. tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph) { + const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph, + bool is_dyn_op) { VLOG(0) << "Starting Calib Conversion"; - tensorflow::Graph graph(tensorflow::OpRegistry::Global()); - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), graph_def, &graph)); - // get calib nodes - std::vector calib_nodes; - std::vector topo_order; - tensorflow::GetPostOrder(graph, &topo_order); - for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { - auto node = *rit; - if (node->type_string() == "TRTCalibOp") { - VLOG(1) << "Found Calib Node " << node->name(); - calib_nodes.push_back(node); - } + infer_graph->CopyFrom(graph_def); + auto trt_rm = TRTResourceManager::instance(); + auto calib_rm = trt_rm->getManager("TRTCalibration"); + int num_nodes = infer_graph->node_size(); + if (!is_dyn_op) { + LOG(WARNING) << "Construction of static int8 engine is not implemented " + "yet!. Dynamic engine will be constructed"; } - VLOG(0) << "Num Calib nodes in graph= " << calib_nodes.size(); - if (calib_nodes.size() == 0) - return tensorflow::errors::FailedPrecondition( - "Graph doesn't contain any calibration nodes!." - " Please generate calibration graph and run calibration first"); - for (auto n : calib_nodes) { - TF_RETURN_IF_ERROR( - tensorrt::convert::ConvertCalibrationNodeToEngineNode(graph, n)); + for (int i = 0; i < num_nodes; ++i) { + auto n = infer_graph->mutable_node(i); + if (n->op() == "TRTEngineOp") { + VLOG(1) << "Processing " << n->name(); + string container_name = n->attr().at("segment_funcdef_name").s(); + TRTCalibrationResource* cres = nullptr; + auto status = calib_rm->Lookup(container_name, "Calibrator", &cres); + if (!status.ok()) { + LOG(ERROR) << "Could not get Calibration information. Did you run with " + "calibration data?"; + return tensorflow::errors::FailedPrecondition( + "Need to run graph with calibration data first!"); + } + if (cres->calibrator_) { + cres->calibrator_->setDone(); + cres->thr_->join(); + const auto& calibration_table = + cres->calibrator_->getCalibrationTableAsString(); + if (!calibration_table.size()) { + LOG(ERROR) << "Calibration table is empty"; + return tensorflow::errors::Unknown( + "Calibration table is missing. This shouldn't have happened!"); + } + n->mutable_attr()->at("calibration_data").set_s(calibration_table); + } else { + LOG(ERROR) << "Can't get TRTCalibrator from resource manager!"; + return tensorflow::errors::Unknown( + "Can't get TRTCalibrator from resource manager!"); + } + cres->Unref(); + } } - graph.ToGraphDef(infer_graph); return tensorflow::Status::OK(); } +// Entry function from Python. tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode = FP32MODE, int minimum_segment_size = 3) { + int precision_mode, int minimum_segment_size, bool is_dyn_op, + int max_cached_engines, std::vector cached_engine_batches) { // optimization pass tensorflow::grappler::GrapplerItem item; item.fetch = output_names; item.graph = graph_def; - + // grappler requires a virtual cluster with a proper GPU device + // in order to calculate flops>0 or fails with FATAL + // We add numbers from a Pascal card here to have flops>0 tensorflow::DeviceProperties device_properties; device_properties.set_type("GPU"); device_properties.mutable_environment()->insert({"architecture", "6"}); - tensorflow::grappler::Cluster* cluster = - new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}}); + device_properties.set_num_cores(3584); + device_properties.set_frequency(1531); + std::unique_ptr cluster( + new tensorflow::grappler::VirtualCluster( + {{"/GPU:0", device_properties}})); // single machine int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); @@ -370,134 +202,633 @@ tensorflow::Status ConvertGraphDefToTensorRT( VLOG(2) << "cpu_cores: " << num_cpu_cores; VLOG(2) << "gpus: " << num_gpus; tensorflow::RewriterConfig rw_cfg; + // use only const folding and layout for the time being since new optimizers + // break the graph for us + rw_cfg.add_optimizers("constfold"); + rw_cfg.add_optimizers("layout"); + rw_cfg.set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE); tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); tensorflow::GraphDef gdef; - TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster, item, &gdef)); + TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, &gdef)); item.graph = gdef; // AJ refactoring shape inference through grappler/GraphProperties. tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); // Build full graph - - return ConvertAfterShapes(gdef, output_names, max_batch_size, - max_workspace_size_bytes, new_graph_def, - precision_mode, minimum_segment_size, - static_graph_properties, nullptr); + ConversionParams cp; + cp.input_graph_def = &gdef; + cp.output_names = &output_names; + cp.max_batch_size = max_batch_size; + cp.output_graph_def = new_graph_def; + cp.precision_mode = precision_mode; + cp.is_dyn_op = is_dyn_op; + cp.max_cached_engines = max_cached_engines; + cp.cached_engine_batches = cached_engine_batches; + cp.minimum_segment_size = minimum_segment_size; + cp.graph_properties = &static_graph_properties; + cp.max_workspace_size_bytes = max_workspace_size_bytes; + if (VLOG_IS_ON(5)) { + std::fstream f; + f.open("TRTConversionInput.pb", + std::fstream::out | std::fstream::binary | std::fstream::trunc); + f << gdef.SerializeAsString(); + f.close(); + } + return ConvertAfterShapes(cp); } -tensorflow::Status ConvertAfterShapes( - const tensorflow::GraphDef& gdef, const std::vector& output_names, - size_t max_batch_size, size_t max_workspace_size_bytes, - tensorflow::GraphDef* new_graph_def, int precision_mode, - int minimum_segment_size, +// Function to get subsegment information structure. +tensorflow::Status GetEngineInfo( + const tensorflow::Graph* g, const tensorflow::grappler::GraphProperties& graph_properties, - const tensorflow::grappler::Cluster* cluster) { - // Segment the graph into subgraphs that can be converted to TensorRT - tensorflow::tensorrt::segment::SegmentOptions segment_options; + const std::set& segment_nodes, + const std::unordered_map& node_map, + const std::vector& reverse_topo_order, + EngineInfo* info) { + std::vector subgraph_node_ids; + std::set segment_devices; + int input_port = 0; + int output_port = 0; + + // Map from src_node_name+port to the unique port numbers of the TRT op, where + // the src_node_name is the name of the source node of the input/output + // edge, thus there must not be any duplicates since source nodes of + // input/output edges must be in different split of the graph. + // TODO(aaroey): consider using node id and port instead. + std::unordered_map created_edges; + for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend(); + ++it) { + const auto& node_name = (*it)->name(); + + if (segment_nodes.count(node_name) == 0) continue; + auto node = node_map.at(node_name); + auto node_device = node->requested_device(); + if (!node_device.empty()) { + segment_devices.insert(node_device); + } else { + if (node->has_assigned_device_name()) { + segment_devices.insert(node->assigned_device_name()); + } else { + VLOG(2) << "Node " << node->name() + << " neither have requested device nor assigned device"; + } + } + int node_id = node->id(); + subgraph_node_ids.push_back(node_id); + for (const auto edge : node->in_edges()) { + auto input_node = edge->src(); + if (segment_nodes.count(input_node->name()) == 0) { + // Add constant input node into the segment. We don't care if it has + // other output edges going into other engines or TF nodes. Since we add + // it only to the subsegment node list, not the subsegment itself, it + // won't be removed from the graph. If it doesn't have any edges, TF + // will prune it out. + if (input_node->type_string() == "Const") { + subgraph_node_ids.push_back(input_node->id()); + } else if (!edge->IsControlEdge() && !input_node->IsSource()) { + string s(input_node->name()); + StrAppend(&s, ":", edge->src_output()); + VLOG(1) << "Input edge = " << s; + int port = input_port; + if (created_edges.count(s)) { + port = created_edges.at(s); + } else { + created_edges.insert({s, port}); + input_port++; + } + info->connections.emplace_back(input_node->name(), input_node->id(), + edge->src_output(), node_name, node_id, + edge->dst_input(), true, port); + } + } + } + for (const auto edge : node->out_edges()) { + auto output_node = edge->dst(); + if (segment_nodes.count(output_node->name()) == 0 && + !edge->IsControlEdge() && !output_node->IsSink()) { + string s(node_name); + StrAppend(&s, ":", edge->src_output()); + VLOG(1) << "Output edge = " << s; + int port = output_port; + if (created_edges.count(s)) { + port = created_edges.at(s); + } else { + created_edges.insert({s, port}); + output_port++; + } + info->connections.emplace_back(output_node->name(), output_node->id(), + edge->dst_input(), node_name, node_id, + edge->src_output(), false, port); + } + } + } + + TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef( + g, graph_properties, subgraph_node_ids, &info->connections, + &info->segment_graph_def, &info->engine_name)); + // TODO(sami): This should not happen once segmenter is updated. + if (segment_devices.size() == 1) { + info->device = *segment_devices.begin(); + } else if (segment_devices.size() > 1) { + LOG(WARNING) << "Detected multiple(" << segment_devices.size() + << ") devices for the segment. Picking first one to continue " + << "but this shouldn't have happened"; + info->device = *segment_devices.begin(); + } else { + VLOG(1) << "Segment devices size is 0"; + } + return Status::OK(); +} + +// Function to insert a TRT node into the graph. The graph is not modified if +// the returned status is not ok. +// 'alloc' is only used for creating static engine. +tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, + const std::vector& infos, int pos, + nvinfer1::IGpuAllocator* alloc, + int max_batch_size) { + const auto& info = infos.at(pos); + std::vector out_shapes; + std::vector input_shapes; + std::vector shapes; + std::vector inputs; + std::vector out_types; + VLOG(1) << "Processing " << info.engine_name; + + // Update the shape and data types of input/output nodes, and find all unique + // inputs. + for (const auto& conn : info.connections) { + if (!conn.is_input_edge) { + // Set the shapes and data types of output edge. + tensorflow::TensorShapeProto out_shape; + // shape of the output node inside segment + conn.inside_shape.AsProto(&out_shape); + if (out_shapes.size() <= conn.port_number) { + out_shapes.resize(conn.port_number + 1); + out_types.resize(conn.port_number + 1); + } + out_shapes.at(conn.port_number) = out_shape; + out_types.at(conn.port_number) = conn.connection_type; + continue; + } + + // Set the shapes and data types of input edge. + tensorflow::TensorShapeProto in_shape; + conn.outside_shape.AsProto(&in_shape); + if (input_shapes.size() <= conn.port_number) { + input_shapes.resize(conn.port_number + 1); + shapes.resize(conn.port_number + 1); + } + input_shapes.at(conn.port_number) = in_shape; + shapes.at(conn.port_number) = conn.outside_shape; + + string input_node = conn.outside_node_name; + int input_port = conn.outside_port; + bool found_engine = false; + // Rewire the inputs to other engines if they contain original input node. + // Note that we use the information of the engine here, not the information + // of the created TRT nodes, so we're able to find all the connections to + // any other engines beforehand. + for (size_t t = 0; t < infos.size(); ++t) { + if (t == pos) continue; + auto& engine_info = infos.at(t); + for (const auto& eng_conn : engine_info.connections) { + if (eng_conn.is_input_edge) continue; + if (eng_conn.inside_node_name == input_node) { + input_node = engine_info.engine_name; + if (eng_conn.inside_port == input_port) { + input_port = eng_conn.port_number; + found_engine = true; + break; + } + } + } + if (found_engine) break; + } + VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> " + << info.engine_name << ":" << inputs.size(); + // Skip duplicate inputs. + bool new_input = true; + for (const auto& inp : inputs) { + if (inp.node == input_node && inp.index == input_port) { + new_input = false; + break; + } + } + if (new_input) { + inputs.emplace_back(input_node, input_port, conn.connection_type); + } + } + + // Build the engine and get its serialized representation. + string segment_string; + if (info.engine_type == EngineInfo::EngineType::TRTStatic || + info.precision_mode == INT8MODE) { + // Create static engine for fp32/fp16 mode, and test validity of the engine + // for int8 mode. We don't want engine to fail at the calibration time. + // So we are constructing a FP32 engine here to check its validity, and if + // it is a valid engine then we put the serialized graphdef to the op. + // Otherwise we skip node creation for this engine. + Logger trt_logger; + TrtUniquePtrType engine; + // TODO(sami): What happens if 1st dim is not batch? + TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( + info.segment_graph_def, + info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode, + max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger, + alloc, /*calibrator=*/nullptr, &engine, + /*convert_successfully=*/nullptr)); + TrtUniquePtrType engine_data(engine->serialize()); + segment_string = + string((const char*)engine_data->data(), engine_data->size()); + if (info.precision_mode == INT8MODE) { + // See above comment about why not putting this inside the 'else' branch. + segment_string = info.segment_graph_def.SerializeAsString(); + } + } else { + segment_string = info.segment_graph_def.SerializeAsString(); + } + + // TODO(aaroey): use enum instead, and add a helper method to do the + // conversion. + string prec_string; + switch (info.precision_mode) { + case FP32MODE: + prec_string = "FP32"; + break; + case FP16MODE: + prec_string = "FP16"; + break; + case INT8MODE: + prec_string = "INT8"; + if (!TRTResourceManager::instance()->getManager("TRTCalibration")) { + LOG(ERROR) << "Failed to construct calibration storage"; + } + break; + default: + return tensorflow::errors::OutOfRange("Unknown precision mode"); + } + tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp"); + if (!info.device.empty()) node_builder.Device(info.device); + if (VLOG_IS_ON(1)) { + string ins = StrCat(info.engine_name, " inputs= "); + for (const auto& ii : inputs) { + StrAppend(&ins, ii.node, ":", ii.index, " "); + } + VLOG(1) << ins; + } + node_builder.Input(inputs); + if (info.engine_type == EngineInfo::EngineType::TRTStatic && + info.cached_engine_batches.size()) { + LOG(WARNING) << "Cached engine batches are ignored for static engines"; + } + tensorflow::NodeDef trt_node; + tensorflow::Status status = + node_builder.Attr("input_shapes", input_shapes) + .Attr("output_shapes", out_shapes) + .Attr("static_engine", + info.engine_type == EngineInfo::EngineType::TRTStatic) + .Attr("segment_funcdef_name", + StrCat(info.engine_name, "_native_segment")) + .Attr("serialized_segment", segment_string) + .Attr("calibration_data", "") + .Attr("max_cached_engines_count", info.maximum_cached_engines) + .Attr("cached_engine_batches", {max_batch_size}) + .Attr("workspace_size_bytes", info.max_workspace_size_bytes) + .Attr("precision_mode", prec_string) + .Attr("OutT", out_types) + .Finalize(&trt_node); + if (!status.ok()) { + LOG(ERROR) << "Node construction failed with" << status; + return status; + } + VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph"; + + // Up until this point, graph is not modified. If we return !status.ok() from + // here, this segment will be skipped + tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); + if (!status.ok()) { + LOG(ERROR) << "Adding node failed " << status; + return status; + } + // Updates the inputs of output edges destination nodes, and point them to the + // engine node. + for (auto& conn : info.connections) { + if (conn.is_input_edge) continue; + VLOG(1) << " Updating DBG " << engine_node->name() << " out_port " + << conn.port_number << " out_id " << conn.outside_id + << " name=" << conn.outside_node_name; + auto dst_node = graph->FindNodeId(conn.outside_id); + // dst_node can only be removed if it is an input node of another engine. + // In this case, other engines input edge is updated in nodedef to point to + // this engine. Even though edge doesn't exists in the graph, when it is + // deserialized again, correct edges will be constructed. This is a problem + // of graph->AddNode(). + if (!dst_node) continue; + VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number + << " to " << dst_node->name() << ":" << conn.outside_port; + auto new_edge = graph->AddEdge(engine_node, conn.port_number, dst_node, + conn.outside_port); + CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() << ":" + << conn.port_number << " -> " << dst_node->name() << ":" + << conn.outside_port; + } + return status; +} + +// Function to construct a funcdef from the segment and add it to the graph. +tensorflow::Status RegisterSegmentFunctionToFunctionLibrary( + tensorflow::Graph* graph, const tensorflow::GraphDef& segment, + const string& name) { + tensorflow::Graph sgraph(graph->flib_def()); + tensorflow::GraphConstructorOptions gcopts; + TF_RETURN_IF_ERROR( + tensorflow::ConvertGraphDefToGraph(gcopts, segment, &sgraph)); + std::map io_nodes; + int num_inputs = 0; + for (auto n : sgraph.op_nodes()) { + if (tensorflow::str_util::StartsWith(n->name(), kInputPHName)) { + num_inputs++; + io_nodes.insert({n->name(), n}); + } else if (tensorflow::str_util::StartsWith(n->name(), kOutputPHName)) { + io_nodes.insert({n->name(), n}); + } + } + + for (int i = 0; i < num_inputs; ++i) { + auto name = StrCat(kInputPHName, i); + auto node = io_nodes[name]; + tensorflow::NodeDef nd; + tensorflow::NodeDefBuilder node_builder( + StrCat(name, "_Arg"), tensorflow::FunctionLibraryDefinition::kArgOp); + VLOG(1) << "Adding " << StrCat(name, "_Arg"); + TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) + .Attr("index", i) + .Finalize(&nd)); + tensorflow::Status s; + auto node_arg = sgraph.AddNode(nd, &s); + if (!s.ok()) { + LOG(ERROR) << "Couldn't add _Arg node for " << name; + } + for (auto edge : node->out_edges()) { + sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input()); + VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0 + << " - > " << edge->dst()->name() << ":" << edge->dst_input(); + if (!s.ok()) { + LOG(ERROR) << "Failed to update edge from " << node_arg->name() + << " to " << edge->dst()->name() << ":" << edge->dst_input(); + } + } + sgraph.RemoveNode(node); + } + + for (int i = 0; i < io_nodes.size() - num_inputs; ++i) { + auto name = StrCat(kOutputPHName, i); + auto node = io_nodes[name]; + tensorflow::NodeDef nd; + tensorflow::NodeDefBuilder node_builder( + StrCat(name, "_Ret"), tensorflow::FunctionLibraryDefinition::kRetOp); + auto edge = *(node->in_edges().begin()); + tensorflow::NodeDefBuilder::NodeOut nout( + edge->src()->name(), edge->src_output(), + edge->src()->output_type(edge->src_output())); + VLOG(1) << " input " << nout.node << ":" << nout.index + << " dtype=" << tensorflow::DataTypeString(nout.data_type); + node_builder.Input({nout}); + TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) + .Attr("index", i) + .Finalize(&nd)); + if (VLOG_IS_ON(3)) { + VLOG(3) << nd.DebugString(); + } + tensorflow::Status s; + auto node_ret = sgraph.AddNode(nd, &s); + if (!s.ok()) { + LOG(ERROR) << "Couldn't add _Ret node for " << name; + } + VLOG(1) << "Update edge from " << edge->src()->name() << ":" + << edge->src_output() << " - > " << node_ret->name() << ":" << 0; + sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0); + s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0); + if (!s.ok()) { + LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":" + << edge->src_output() << " - > " << node_ret->name() << ":" + << 0; + } + sgraph.RemoveNode(node); + } + tensorflow::FunctionDefLibrary fdeflib; + auto native_segment = fdeflib.add_function(); + TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef( + sgraph, StrCat(name, "_native_segment"), native_segment)); + if (VLOG_IS_ON(7)) { + VLOG(7) << name << " Function_Def "; + VLOG(7) << native_segment->DebugString(); + } + VLOG(1) << "Adding funcdef to graphlib"; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib)); + return tensorflow::Status::OK(); +} + +std::pair GetDeviceAndAllocator( + ConversionParams& params, EngineInfo& engine) { + int cuda_device_id = -1; + auto check_device_id = [](int tfid) -> int { + tensorflow::TfGpuId tf_gpu_id(tfid); + CudaGpuId cuda_gpu_id; + Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + if (s.ok()) { + VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device " + << cuda_gpu_id.value(); + return cuda_gpu_id.value(); + } + VLOG(2) << "TF GPU with id " << tfid << " do not exist " << s; + return -1; + }; + tensorflow::Allocator* dev_allocator = nullptr; + // we need to us PM here since in python path there is no way to get + // to allocators. + // TODO(sami): when grappler devices become available else path will not be + // necessary + auto pm = tensorflow::ProcessState::singleton(); + if (params.cluster) { // get allocator + tensorflow::Device* device = nullptr; + if (params.cluster->GetDeviceSet()) { + device = params.cluster->GetDeviceSet()->FindDeviceByName(engine.device); + } + if (device) { + tensorflow::AllocatorAttributes alloc_attr; + dev_allocator = device->GetAllocator(alloc_attr); + VLOG(1) << "Using allocator " << dev_allocator->Name(); + } else { + LOG(WARNING) << "Cluster is set but device '" << engine.device + << "' is not found in the cluster"; + } + } else { // cluster not found, possibly a python call + VLOG(1) << "Cluster is not set, probably called from python"; + int found_device = 0; + bool try_gpu_ids = true; + // if device is set, try to find the device. Might be a problem for multi + // host case but TensorRT do not support multi host setups yet. + if (!engine.device.empty()) { + DeviceNameUtils::ParsedName parsed_name; + if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name)) { + cuda_device_id = parsed_name.has_id ? parsed_name.id : -1; + } + try_gpu_ids = !parsed_name.has_id; + } + if (try_gpu_ids) { + while (found_device < 100) { + cuda_device_id = check_device_id(found_device); + if (cuda_device_id >= 0) break; + found_device++; + } + } + if (found_device == 100) { + LOG(ERROR) << " Can't find a GPU device to work with. Please " + "instantiate a session to initialize devices"; + return std::make_pair(cuda_device_id, dev_allocator); + } + LOG(WARNING) + << "Can't determine the device, constructing an allocator at device " + << found_device; + tensorflow::GPUOptions gpuoptions; + // this will be a noop if device is already initialized + gpuoptions.set_allow_growth(true); + tensorflow::TfGpuId tf_gpu_id(found_device); + dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1); + } + return std::make_pair(cuda_device_id, dev_allocator); +} + +// Entry function from optimization pass. +tensorflow::Status ConvertAfterShapes(ConversionParams& params) { + // Convert graphdef to graph. tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(), - gdef.library()); + params.input_graph_def->library()); tensorflow::Graph graph(flib); TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph( - tensorflow::GraphConstructorOptions(), gdef, &graph)); + tensorflow::GraphConstructorOptions(), *params.input_graph_def, &graph)); + // Segment the graph into subgraphs that can be converted to TensorRT + tensorflow::tensorrt::segment::SegmentOptions segment_options; // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT) - for (auto node : output_names) { + for (auto node : *(params.output_names)) { segment_options.exclude_node_list.insert(node); } - - // TODO(sami): this should be passed as a knob!!!! - segment_options.minimum_segment_size = minimum_segment_size; - tensorflow::tensorrt::segment::SegmentNodesVector segments; + segment_options.minimum_segment_size = params.minimum_segment_size; + tensorflow::tensorrt::segment::SegmentNodesVector initial_segments; TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph( - &graph, IsTensorRTCandidate, segment_options, &segments)); - if (segments.size() > 1) { - VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size(); + &graph, IsTensorRTCandidate, segment_options, &initial_segments)); + if (initial_segments.size() > 1) { + VLOG(0) << "MULTIPLE tensorrt candidate conversion: " + << initial_segments.size(); } + + // Get the EngineInfo for each segment. std::unordered_map node_map; TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map)); - std::unordered_map> output_edge_map; - int count = 0; float total_num_nodes_in_segments = 0.; - for (auto s : segments) { - total_num_nodes_in_segments += s.first.size(); - } - // We create the map here since cluster may not be available in all cases. - std::map name_to_device_map; - if (cluster) { - // TODO(aaroey): consider using DeviceSet::FindDeviceByName(), as in a - // distributed environment, devices from different workers can have same - // short name. - for (const auto dm : cluster->GetDeviceSet()->devices()) { - name_to_device_map[dm->name()] = dm; + std::vector engine_segments; + engine_segments.reserve(initial_segments.size()); + std::vector reverse_topo_order; + tensorflow::GetPostOrder(graph, &reverse_topo_order); + size_t total_engine_bytes_size = 0; + std::vector engine_bytes_size; + tensorflow::tensorrt::segment::SegmentNodesVector converted_segments; + converted_segments.reserve(initial_segments.size()); + for (size_t t = 0; t < initial_segments.size(); t++) { + auto& curr_segment = initial_segments.at(t); + EngineInfo curr_engine; + Status status = + GetEngineInfo(&graph, *params.graph_properties, curr_segment.first, + node_map, reverse_topo_order, &curr_engine); + if (!status.ok()) { + LOG(WARNING) << "Failed to get engine info for segment " << t << ": " + << status; + continue; } - } - for (const auto& segment_nodes_and_device : segments) { - const std::set& subgraph_node_names = - segment_nodes_and_device.first; - std::set subgraph_node_ids; - size_t max_mem_per_engine = - max_workspace_size_bytes * - ((float)subgraph_node_names.size() / total_num_nodes_in_segments); - std::stringstream oss; - for (const string& node_name : subgraph_node_names) { - oss << " " << node_name; - subgraph_node_ids.insert(node_map.at(node_name)->id()); + curr_engine.precision_mode = params.precision_mode; + curr_engine.engine_type = + (params.is_dyn_op || params.precision_mode == INT8MODE + ? EngineInfo::EngineType::TRTDynamic + : EngineInfo::EngineType::TRTStatic); + curr_engine.cached_engine_batches = params.cached_engine_batches; + curr_engine.maximum_cached_engines = params.max_cached_engines; + StrAppend(&curr_engine.engine_name, "my_trt_op_", t); + status = RegisterSegmentFunctionToFunctionLibrary( + &graph, curr_engine.segment_graph_def, curr_engine.engine_name); + if (!status.ok()) { + LOG(WARNING) << "Failed to register segment graphdef as a function " << t + << ": " << status; + continue; } - VLOG(1) << "Subgraph nodes at device " << segment_nodes_and_device.second - << " : " << oss.str(); - auto target_device = - name_to_device_map.find(segment_nodes_and_device.second); - std::shared_ptr allocator(0); + engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong()); + total_engine_bytes_size += engine_bytes_size.back(); + total_num_nodes_in_segments += curr_segment.first.size(); + engine_segments.push_back(std::move(curr_engine)); + converted_segments.push_back(std::move(curr_segment)); + + if (VLOG_IS_ON(8)) { + string fname = curr_engine.engine_name; + StrAppend(&fname, ".pb"); + std::fstream f; + f.open(fname.c_str(), std::fstream::out | std::fstream::binary); + f << engine_segments.at(t).segment_graph_def.SerializeAsString(); + f.close(); + } + } + + // Create a TRT node for each segment using its EngineInfo. + int old_cuda_device = 0; + auto err = cudaGetDevice(&old_cuda_device); + if (err != cudaSuccess) { + LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err); + } + VLOG(1) << "Current cuda device is " << old_cuda_device; + for (int i = 0; i < engine_segments.size(); ++i) { + auto& engine = engine_segments.at(i); + // Partition the workspace size by the average of node ratio and segment + // graphdef size + engine.max_workspace_size_bytes = + params.max_workspace_size_bytes * + (engine_bytes_size.at(i) / total_engine_bytes_size + + converted_segments.at(i).first.size() / total_num_nodes_in_segments) / + 2.0; + // The allocator is used to build the engine. The build and the built engine + // will be destroyed after we get the serialized engine string, so it's fine + // to use unique_ptr here. + std::unique_ptr alloc; + auto device_alloc = GetDeviceAndAllocator(params, engine); int cuda_device_id = 0; - if (target_device != name_to_device_map.end()) { - tensorflow::TfGpuId tf_gpu_id(target_device->second->parsed_name().id); - CudaGpuId cuda_gpu_id; - Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); - if (!s.ok()) { - LOG(ERROR) - << "Cuda device identification failed, using device 0. Error= " - << s; - } else { - cuda_device_id = cuda_gpu_id.value(); - } - tensorflow::GPUOptions gpuoptions; - // we need to us PM here since in python path there is no way to get to - // allocators - auto pm = tensorflow::ProcessState::singleton(); - // this should be instantiated by now - auto dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1); - VLOG(1) << "Got an allocator for device tf_device=" << tf_gpu_id.value() - << " cuda device= " << cuda_device_id << " at " << dev_allocator; - allocator = std::make_shared(dev_allocator); - } else { // device unknown or not available - allocator = std::make_shared(); + if (device_alloc.first >= 0) { + cuda_device_id = device_alloc.first; + alloc.reset(new TRTDeviceAllocator(device_alloc.second)); + } else { + // Setting allocator as nullptr should get revert to the cudamalloc + LOG(WARNING) << "Can't identify the cuda device. Running on device 0 "; } - ConvertGraphParams p(graph, output_names, subgraph_node_ids, max_batch_size, - max_mem_per_engine, graph_properties, &output_edge_map, - precision_mode, segment_nodes_and_device.second, - allocator, cuda_device_id); - if (precision_mode == INT8MODE) { - tensorflow::Status status = GetCalibNode(&p); - if (status != tensorflow::Status::OK()) { - LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count - << " due to: \"" << status.ToString() - << "\" SKIPPING......( " << subgraph_node_names.size() - << " nodes)"; + cudaSetDevice(cuda_device_id); + auto status = CreateTRTNode(&graph, engine_segments, i, alloc.get(), + params.max_batch_size); + // If status is ok, we successfully added the node to the graph and can + // remove segment ops. Otherwise graph is not modified. + if (status.ok()) { + for (auto node_name : converted_segments.at(i).first) { + graph.RemoveNode(node_map.at(node_name)); } } else { - tensorflow::Status status = ConvertSubGraphToTensorRT(&p); - if (status != tensorflow::Status::OK()) { - LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count - << " due to: \"" << status.ToString() - << "\" SKIPPING......( " << subgraph_node_names.size() - << " nodes)"; - } + // Graph is not modified. + LOG(WARNING) << "Engine creation for segment " << i << ", composed of " + << converted_segments.at(i).first.size() << " nodes failed: " + << status << ". Skipping..."; } - count++; } - graph.ToGraphDef(new_graph_def); + cudaSetDevice(old_cuda_device); + graph.ToGraphDef(params.output_graph_def); + VLOG(1) << "Returning from conversion"; return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 65a67d7e73e32f904bd636a4f4aaefe32b0c092d..9d986e489043c0a0e16e379166aa2e8f7ac0b11f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -30,29 +30,60 @@ namespace tensorflow { namespace tensorrt { namespace convert { -// This method converts an already generated calibration graph which was used in -// calibration runs to an inference graph +struct ConversionParams { + ConversionParams() + : input_graph_def(nullptr), + max_batch_size(1), + max_workspace_size_bytes(1 << 30), + output_graph_def(nullptr), + precision_mode(1), + minimum_segment_size(3), + graph_properties(nullptr), + cluster(nullptr), + is_dyn_op(false), + fixed_input_size(true), + max_cached_engines(1) {} + const tensorflow::GraphDef* input_graph_def; + const std::vector* output_names; + size_t max_batch_size; + size_t max_workspace_size_bytes; + tensorflow::GraphDef* output_graph_def; + int precision_mode; + int minimum_segment_size; + const tensorflow::grappler::GraphProperties* graph_properties; + const tensorflow::grappler::Cluster* cluster; + bool is_dyn_op; // Whether to create engine on conversion or execution time + bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed + int max_cached_engines; // maximum number of cached engines + std::vector cached_engine_batches; // list of cached engines +}; + +// This method extracts calibration information from the resource managers +// and puts them in to engine nodedefs. tensorflow::Status ConvertCalibGraphToInferGraph( - const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def); + const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def, + bool is_dyn_op); -// max_batch_size: maximum batch size which can be used for inference for -// optimization targets inference run with max batch size. -// max_workspace_size_bytes: The upper bound of memory allowance for -// engine building. +// - max_batch_size: maximum batch size which can be used for inference for +// optimization targets inference run with max batch size. +// - max_workspace_size_bytes: The upper bound of memory allowance for engine +// building. tensorflow::Status ConvertGraphDefToTensorRT( const tensorflow::GraphDef& graph_def, const std::vector& output_names, size_t max_batch_size, size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def, - int precision_mode, int minimum_segment_size); + int precision_mode = 1, int minimum_segment_size = 3, + bool is_dyn_op = false, int max_cached_engines = 1, + std::vector cached_engine_batches = {}); // Method to call from optimization pass -tensorflow::Status ConvertAfterShapes( - const tensorflow::GraphDef& graph, const std::vector& output_names, - size_t max_batch_size, size_t max_workspace_size_bytes, - tensorflow::GraphDef* new_graph_def, int precision_mode, - int minimum_segment_size, - const tensorflow::grappler::GraphProperties& graph_properties, - const tensorflow::grappler::Cluster* cluster); +tensorflow::Status ConvertAfterShapes(ConversionParams& params); + +// Return compile time TensorRT library version information. +std::vector GetLinkedTensorRTVersion(); + +// Return runtime time TensorRT library version information. +std::vector GetLoadedTensorRTVersion(); } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 4e4d295538edadd26a347a38ec141737f097f26f..146b9c7344b0a9c2b3ec87b395e9b1096dbef06c 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -25,7 +24,9 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT @@ -37,6 +38,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -54,8 +56,11 @@ limitations under the License. namespace tensorflow { namespace tensorrt { namespace convert { +using ::tensorflow::str_util::Split; + using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; + namespace { inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, @@ -121,12 +126,10 @@ static std::vector> CreateSamePadding( string GetCommonNameScope(const string& op_name_a, const string& op_name_b) { size_t last_scope_separator = 0; - for (size_t i = 0; i < std::min(op_name_a.size(), op_name_b.size()); ++i) { - if (op_name_a[i] != op_name_b[i]) { - break; - } else if (op_name_a[i] == '/') { - last_scope_separator = i + 1; - } + const size_t min_size = std::min(op_name_a.size(), op_name_b.size()); + for (size_t i = 0; i < min_size; ++i) { + if (op_name_a[i] != op_name_b[i]) break; + if (op_name_a[i] == '/') last_scope_separator = i + 1; } return op_name_a.substr(0, last_scope_separator); } @@ -417,20 +420,6 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, } } -struct InferDeleter { - template - void operator()(T* obj) const { - if (obj) { - obj->destroy(); - } - } -}; - -template -inline std::shared_ptr infer_object(T* obj) { - return std::shared_ptr(obj, InferDeleter()); -} - class Converter; using OpConverter = @@ -444,7 +433,7 @@ class Converter { OpConverter plugin_converter_; nvinfer1::INetworkDefinition* trt_network_; std::list> temp_bufs_; - tensorflow::tensorrt::TRTWeightStore* weight_store_; + TRTWeightStore* weight_store_; bool fp16_; void register_op_converters(); tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def, @@ -486,11 +475,11 @@ class Converter { public: explicit Converter(nvinfer1::INetworkDefinition* trt_network, - tensorflow::tensorrt::TRTWeightStore* ws, bool fp16) + TRTWeightStore* ws, bool fp16) : trt_network_(trt_network), weight_store_(ws), fp16_(fp16) { this->register_op_converters(); } - tensorflow::tensorrt::TRTWeightStore* weight_store() { return weight_store_; } + TRTWeightStore* weight_store() { return weight_store_; } TRT_ShapedWeights get_temp_weights(tensorflow::DataType type, nvinfer1::Dims shape) { TRT_ShapedWeights weights(type, nullptr, shape); @@ -2140,559 +2129,265 @@ void Converter::register_op_converters() { } // namespace -tensorflow::Status ConvertCalibrationNodeToEngineNode( - tensorflow::Graph& graph, tensorflow::Node* c_node) { - const auto ndef = c_node->def(); - - TFAttrs attrs(ndef); - std::vector segment_nodes( - attrs.get>("segment_nodes")); - std::vector output_nodes( - attrs.get>("segment_output_names")); - std::vector input_names( - attrs.get>("input_names")); - string res_name = attrs.get("resource_name"); - VLOG(1) << "Node name " << c_node->name() << " res_name " << res_name; - string engine_name = "my_trt_op"; - { - const auto node_id = tensorflow::str_util::Split(res_name, "_"); - engine_name += node_id.back(); - } - std::map node_maps; - - for (auto n : graph.op_nodes()) { - node_maps.insert({n->name(), n}); - } - std::set subgraph_ids; - for (const auto internal_node : segment_nodes) { - subgraph_ids.insert(node_maps.at(internal_node)->id()); - } - if (VLOG_IS_ON(2)) { - string node_names = StrCat(c_node->name(), " segment nodes= "); - - for (const auto& node_name : segment_nodes) { - StrAppend(&node_names, node_name, ", "); - } - VLOG(2) << node_names; +tensorflow::Status ConvertGraphDefToEngine( + const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, + size_t max_workspace_size_bytes, + const std::vector& input_shapes, + Logger* logger, nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator, + TrtUniquePtrType* engine, + bool* convert_successfully) { + engine->reset(); + if (convert_successfully) *convert_successfully = false; + + // Create the builder. + TrtUniquePtrType builder( + nvinfer1::createInferBuilder(*logger)); + builder->setMaxBatchSize(max_batch_size); + // TODO(aaroey): use the allocator to allocate the TRT workspace. + builder->setMaxWorkspaceSize(max_workspace_size_bytes); +#if NV_TENSORRT_MAJOR > 3 + builder->setGpuAllocator(allocator); +#endif + if (precision_mode == FP16MODE) { + builder->setHalf2Mode(true); + } else if (precision_mode == INT8MODE) { + builder->setInt8Mode(true); + builder->setInt8Calibrator(calibrator); } - VLOG(1) << "Output Nodes:"; - std::vector out_types; - std::vector out_edges; + // Create the network. + auto trt_network = + TrtUniquePtrType(builder->createNetwork()); + if (!trt_network) { + return tensorflow::errors::Internal( + "Failed to create TensorRT network object"); + } + auto ws = std::unique_ptr(new TRTWeightStore()); - for (auto& i : output_nodes) { - auto node_port = tensorflow::str_util::Split(i, ":"); - VLOG(1) << " " << i << " in graph " << node_maps.count(i); - auto out_node_name = node_port.at(0); - if (node_port.size() > 1) { - VLOG(1) << "Multi port output" << node_port.at(0) << " " - << node_port.at(1) << " size=" << node_port.size(); - } - auto node_it = node_maps.find(out_node_name); - if (node_it != node_maps.end()) { - tensorflow::Node* out_node = node_it->second; - int port = 0; - if (node_port.size() == 2) { - port = std::strtoul(node_port.at(1).c_str(), nullptr, 10); - out_types.push_back(out_node->output_type(port)); - } else { - out_types.push_back(out_node->output_type(0)); + // Build the network + VLOG(1) << "Starting engine conversion "; + Converter converter(trt_network.get(), ws.get(), precision_mode == FP16MODE); + std::vector> output_tensors; + // Graph nodes are already topologically sorted during construction + for (const auto& node_def : gdef.node()) { + string node_name = node_def.name(); + VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op(); + if (tensorflow::str_util::StartsWith(node_name, kInputPHName) && + (node_def.op() == "Placeholder")) { + nvinfer1::DimsCHW input_dim_pseudo_chw; + for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0; + nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); + auto type_status = + ConvertDType(node_def.attr().at("dtype").type(), &dtype); + if (type_status != tensorflow::Status::OK()) { + LOG(WARNING) << "Type conversion failed for " << node_name; + return type_status; } - for (auto out_edge : out_node->out_edges()) { - if (subgraph_ids.count(out_edge->dst()->id())) - continue; // skip internal edges; - if (out_edge->src_output() == port) { - out_edges.push_back(out_edge); - VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":" - << out_edge->src_output() << " -> " << out_edge->dst()->name() - << ":" << out_edge->dst_input(); + int32 slot_number = -1; + if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8, + &slot_number)) { + LOG(ERROR) << "Failed to parse slot number from " << node_name + << " +8= " << node_name.c_str() + 8; + } + auto shape = input_shapes.at(slot_number); + if (shape.dims() > 8) { + LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name + << " at input slot " << slot_number; + return tensorflow::errors::OutOfRange( + "Input tensor rank is greater than 8"); + } + if (VLOG_IS_ON(1)) { + string dim_str("dims="); + StrAppend(&dim_str, "[ ", shape.dim_size(0)); + for (int i = 1; i < shape.dims(); i++) { + StrAppend(&dim_str, ", ", shape.dim_size(i)); } + StrAppend(&dim_str, " ]"); + VLOG(1) << dim_str; + } + for (int i = 1; i < shape.dims(); i++) { + input_dim_pseudo_chw.d[i - 1] = shape.dim_size(i); } - } else { - LOG(WARNING) << " couldn't find output node " << out_node_name; - } - } - if (VLOG_IS_ON(1)) { - VLOG(1) << c_node->name() << " Input Nodes:"; - for (auto& i : input_names) { - VLOG(1) << " Input " << i << " in graph " << node_maps.count(i); - } - } - auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); - auto resmgr = trt_rm->getManager("TRTCalibOps"); - tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; - auto status = resmgr->Lookup(res_name, res_name, &calib_res); - if (!status.ok() || !calib_res->calibrator_) { - return tensorflow::errors::FailedPrecondition( - "You must run calibration" - " and inference conversion in the same process"); - } - - calib_res->calibrator_->setDone(); - calib_res->thr_->join(); - delete calib_res->thr_; - if (!calib_res->engine_) { - LOG(ERROR) << "Calibration failed!, engine does not exist. Did you run " - "calibration graph?"; - return tensorflow::errors::FailedPrecondition( - "Calibration graph needs to be executed on" - " calibration data before convertsion to inference graph"); - } - auto weight_rmgr = trt_rm->getManager("WeightStore"); - TF_CHECK_OK(weight_rmgr->Delete( - res_name, res_name)); - auto engine_plan = calib_res->engine_->serialize(); - calib_res->engine_->destroy(); - calib_res->network_->destroy(); - calib_res->builder_->destroy(); - calib_res->thr_ = nullptr; - calib_res->engine_ = nullptr; - calib_res->builder_ = nullptr; - tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); - std::vector income_edges; - income_edges.resize(c_node->num_inputs()); - for (const auto in_edge : c_node->in_edges()) { - auto src = in_edge->src(); - int dest_port = in_edge->dst_input(); - VLOG(1) << "Incoming connection " << src->name() << ":" - << in_edge->src_output() << " -> " << c_node->name() << ":" - << dest_port; - income_edges.at(dest_port) = {src->name(), in_edge->src_output(), - c_node->input_type(dest_port)}; - } - tensorflow::gtl::ArraySlice input_list( - income_edges); - if (VLOG_IS_ON(2)) { - for (const auto& inp : input_list) { - VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " " - << tensorflow::DataTypeString(inp.data_type); - } - } - op_builder.Input(input_list); - tensorflow::NodeDef engine_node; - const char* engine_plan_data = static_cast(engine_plan->data()); - string engine_plan_string(engine_plan_data, - engine_plan_data + engine_plan->size()); - status = op_builder.Attr("serialized_engine", engine_plan_string) - .Attr("input_nodes", input_names) - .Attr("output_nodes", output_nodes) - .Attr("OutT", out_types) - .Finalize(&engine_node); - if (!status.ok()) { - LOG(ERROR) << "Engine Node creation failed"; - return status; - } - auto trt_engine_node = graph.AddNode(engine_node, &status); - TF_RETURN_IF_ERROR(status); - std::map port_map; - for (size_t t = 0; t < output_nodes.size(); t++) { - port_map.insert({output_nodes.at(t), t}); - } - for (auto& i : out_edges) { - string s(i->src()->name()); - if (i->src_output()) StrAppend(&s, ":", i->src_output()); - int out_port = port_map.at(s); - VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port - << " -> " << i->dst()->name() << ":" << i->dst_input(); - TF_RETURN_IF_ERROR( - graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input())); - } - for (const auto ed : trt_engine_node->in_edges()) { - VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output() - << " -> " << ed->dst()->name() << ":" << ed->dst_input(); - } - for (const auto ed : trt_engine_node->out_edges()) { - VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output() - << " -> " << ed->dst()->name() << ":" << ed->dst_input(); - } - VLOG(1) << "Segment nodes:"; - for (auto& i : segment_nodes) { - VLOG(1) << " " << i << " in graph " << node_maps.count(i); - auto it = node_maps.find(i); - if (it != node_maps.end()) { - graph.RemoveNode(it->second); - } - } - graph.RemoveNode(c_node); - return tensorflow::Status::OK(); -} -tensorflow::Status ReverseTopologicalSort( - const tensorrt::convert::SubGraphParams& s, - std::list* order) { - std::vector order_vec; - tensorflow::GetPostOrder(s.graph, &order_vec); - // Select just the subgraph - for (tensorflow::Node* node : order_vec) { - if (s.subgraph_node_ids.count(node->id())) { - // We want topological order to contstruct the - // network layer by layer - order->push_front(node); + input_dim_pseudo_chw.nbDims = shape.dims() - 1; + nvinfer1::ITensor* input_tensor = converter.network()->addInput( + node_name.c_str(), dtype, input_dim_pseudo_chw); + if (!input_tensor) { + return tensorflow::errors::InvalidArgument( + "Failed to create Input layer tensor ", node_name, + " rank=", shape.dims() - 1); + } + VLOG(1) << "Input tensor name :" << node_name; + if (!converter.insert_input_tensor(node_name, input_tensor)) { + return tensorflow::errors::AlreadyExists( + "Output tensor already exists for op: " + node_name); + } + } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && + (node_def.op() == "Identity")) { + int32 slot_number = -1; + if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9, + &slot_number)) { + LOG(ERROR) << "Failed to parse slot number from " << node_name + << " +9=" << node_name.c_str() + 9; + } + if (output_tensors.size() <= slot_number) { + output_tensors.resize(slot_number + 1); + } + output_tensors.at(slot_number) = {node_def.input(0), node_name}; + } else { + VLOG(2) << "Converting node: " << node_def.name() << " , " + << node_def.op(); + TF_RETURN_IF_ERROR(converter.convert_node(node_def)); } } - return tensorflow::Status::OK(); -} - -tensorflow::Status SetInputList( - const tensorrt::convert::SubGraphParams& s, - tensorflow::NodeDefBuilder* op_builder, - const std::vector* input_names, - std::vector* input_dtypes) { - std::vector income_edges; - VLOG(2) << "input edge size: " << input_names->size(); - for (size_t i = 0; i < input_names->size(); ++i) { - VLOG(2) << "input edges: " << i << " " << input_names->at(i); - int output_idx = s.input_inds.at(i).second; - // we wired up the input here already, it is redundant to do it again in - // ConvertSubGraphToTensorRT(convert_graph.cc) - auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut( - input_names->at(i), output_idx, input_dtypes->at(i)); - income_edges.push_back(incoming_edge); - } - tensorflow::gtl::ArraySlice input_list( - income_edges); - op_builder->Input(input_list); - return tensorflow::Status::OK(); -} - -string SubgraphNameScopeGenerator(const std::list* order) { - string subgraph_name_scope; - if (!order->empty()) { - subgraph_name_scope = order->front()->name(); - } - for (const tensorflow::Node* node : *order) { - subgraph_name_scope = GetCommonNameScope(subgraph_name_scope, node->name()); - } - // TODO(sami,ben,jie): proper naming! - return subgraph_name_scope; -} - -tensorflow::Status ConvertSubgraph( - Converter& converter, tensorrt::convert::SubGraphParams& s, - std::list* order, std::vector* input_names, - std::vector* input_dtypes, - std::vector* output_names, - std::vector* output_dtypes, - const string& engine_name) { - std::set added_tensors; - for (const std::pair& input : s.input_inds) { - VLOG(2) << "parsing input. Node id= " << input.first; - int node_id = input.first; - int output_idx = input.second; - tensorflow::Node* node = s.graph.FindNodeId(node_id); - auto node_name = node->name(); - // input_names should use the node name in the graph - // here it should be the input tensor name -> matching the binding - // insert original node name without port - auto tensor_name = node_name; - if (output_idx != 0) { - tensor_name = StrCat(tensor_name, ":", output_idx); - } - - VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name - << " idx: " << output_idx; - - auto shape_inference_node_name = node_name; - auto shape_inference_output_idx = output_idx; - // rewire the shape inference to original node in the graph - if (s.output_edge_map->count(tensor_name)) { - shape_inference_node_name = s.output_edge_map->at(tensor_name).second; - shape_inference_output_idx = s.output_edge_map->at(tensor_name).first; - } - if (shape_inference_output_idx < 0) continue; - VLOG(2) << "shapeinference name: " << shape_inference_node_name - << " idx: " << shape_inference_output_idx; - - if (!s.graph_properties.HasOutputProperties(shape_inference_node_name)) - return tensorflow::errors::Internal("failed to find input node: " + - shape_inference_node_name); - - auto op_info_vec = - s.graph_properties.GetOutputProperties(shape_inference_node_name); - if (static_cast(op_info_vec.size()) <= shape_inference_output_idx) - return tensorflow::errors::Internal( - "accessing output index of: ", shape_inference_output_idx, - ", at node: ", shape_inference_node_name, - " with output entry from shape_map: ", op_info_vec.size()); - - auto op_info = op_info_vec.at(shape_inference_output_idx); - tensorflow::DataType tf_dtype = op_info.dtype(); - - nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); - auto type_status = ConvertDType(tf_dtype, &dtype); - if (type_status != tensorflow::Status::OK()) { - LOG(WARNING) << "Type conversion failed for " << node_name; - return type_status; - } - - VLOG(2) << "Accessing output index of: " << output_idx - << ", at node: " << node_name - << " with output entry from shape_map: " << op_info_vec.size(); - // TODO(ben,jie): update TRT input format/dimension - nvinfer1::DimsCHW input_dim_pseudo_chw; - for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1; - - // TODO(jie): TRT 3.x only support 4 dimensional input tensor. - // update the code once TRT 4.0 comes out. - if (op_info.shape().dim_size() != 4) { - string err_str = "Require 4 dimensional input."; - StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ", - shape_inference_node_name); - return tensorflow::errors::Unimplemented(err_str); - } - - for (int i = 1; i < op_info.shape().dim_size(); i++) { - VLOG(2) << "dimension: " << i - << " , size: " << op_info.shape().dim(i).size(); - input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size(); - } - - // TODO(ben,jie): proper way to restore input tensor name? - auto input_tensor_name = node_name; - if (output_idx != 0) { - input_tensor_name = StrCat(node_name, ":", output_idx); - } - if (added_tensors.count(input_tensor_name)) continue; - added_tensors.insert(input_tensor_name); - input_names->push_back(input_tensor_name); - input_dtypes->push_back(tf_dtype); - nvinfer1::ITensor* input_tensor = converter.network()->addInput( - input_tensor_name.c_str(), dtype, input_dim_pseudo_chw); - - if (!input_tensor) - return tensorflow::errors::InvalidArgument( - "Failed to create Input layer"); - VLOG(2) << "Input tensor name :" << input_tensor_name; - - if (!converter.insert_input_tensor(input_tensor_name, input_tensor)) - return tensorflow::errors::AlreadyExists( - "Output tensor already exists for op: " + input_tensor_name); - } - - for (const tensorflow::Node* node : *order) { - const tensorflow::NodeDef& node_def = node->def(); - VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op(); - TF_RETURN_IF_ERROR(converter.convert_node(node_def)); - } - - VLOG(2) << "Finished conversion"; - - // Gather output metadata - int trt_engine_op_output_idx = 0; - added_tensors.clear(); - for (const std::pair& output : s.output_inds) { - int node_id = output.first; - int output_idx = output.second; - tensorflow::Node* node = s.graph.FindNodeId(node_id); - string op_name = node->name(); - string tensor_name = op_name; - - s.output_edge_map->insert( - {trt_engine_op_output_idx == 0 - ? engine_name - : StrCat(engine_name, ":", trt_engine_op_output_idx), - {output_idx, tensor_name}}); - trt_engine_op_output_idx++; - if (output_idx != 0) - tensorflow::strings::StrAppend(&tensor_name, ":", output_idx); - VLOG(2) << "Output tensor name: " << tensor_name; - if (added_tensors.count(tensor_name)) continue; - added_tensors.insert(tensor_name); - output_names->push_back(tensor_name); - auto tensor_or_weights = converter.get_tensor(tensor_name); + for (const auto& output : output_tensors) { + auto tensor_or_weights = converter.get_tensor(output.first); if (!tensor_or_weights.is_tensor()) { - return tensorflow::errors::InvalidArgument("Output node '" + tensor_name + - "' is weights not tensor"); + return tensorflow::errors::InvalidArgument( + "Output node '" + output.first + "' is weights not tensor"); } nvinfer1::ITensor* tensor = tensor_or_weights.tensor(); + tensor->setName(output.second.c_str()); if (!tensor) { return tensorflow::errors::NotFound("Output tensor not found: " + - tensor_name); + output.first); } + VLOG(1) << "Marking output tensor " << output.first << ", as output tensor " + << output.second; + converter.network()->markOutput(*tensor); - tensorflow::DataType tf_dtype = node->output_type(output_idx); - output_dtypes->push_back(tf_dtype); - nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; - TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype)); - tensor->setType(trt_dtype); } + if (convert_successfully) *convert_successfully = true; - return tensorflow::Status::OK(); -} - -tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) { - // Visit nodes in reverse topological order and construct the TRT network. - // Toposort - std::list order; - TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order)); - - static int static_id = 0; - string subgraph_name_scope = SubgraphNameScopeGenerator(&order); - // TODO(sami,ben,jie): proper naming! - string calib_op_name = - StrCat(subgraph_name_scope, "my_trt_calib_op_", static_id); - string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id); - static_id++; - - auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance(); - auto op_rmgr = trt_rmgr->getManager("TRTCalibOps"); - auto op_res = new tensorflow::tensorrt::TRTCalibrationResource(); - TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res)); - op_res->logger_ = new tensorflow::tensorrt::Logger(); - cudaSetDevice(s.cuda_gpu_id_); - op_res->builder_ = nvinfer1::createInferBuilder(*(op_res->logger_)); - op_res->allocator_ = s.allocator_; -#if NV_TENSORRT_MAJOR > 3 - op_res->builder_->setGpuAllocator(s.allocator_.get()); -#endif - if (!op_res->builder_) { - return tensorflow::errors::Internal( - "failed to create TensorRT builder object"); + // Build the engine. + VLOG(1) << "Starting engine creation"; + engine->reset(builder->buildCudaEngine(*converter.network())); + if (engine->get() == nullptr) { + return tensorflow::errors::Internal("Failed to build TensorRT engine"); } - - op_res->network_ = op_res->builder_->createNetwork(); - if (!op_res->network_) { - return tensorflow::errors::Internal( - "failed to create TensorRT network object"); - } - - // Build the network - auto weight_rmgr = trt_rmgr->getManager("WeightStore"); - auto ws = new tensorflow::tensorrt::TRTWeightStore(); - TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws)); - Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE); - - std::vector input_names; - std::vector input_dtypes; - std::vector output_names; - std::vector output_dtypes; - TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names, - &input_dtypes, &output_names, - &output_dtypes, engine_name)); - - VLOG(2) << "Finished processing outputs"; - - // Build the engine - op_res->builder_->setMaxBatchSize(s.max_batch_size); - op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes); - VLOG(0) << "Max batch size= " << s.max_batch_size - << " max workspace size= " << s.max_workspace_size_bytes; - - // Build the TRT op - // TODO(sami,ben,jie): proper naming! - tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp"); - TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); - - std::vector segment_names; - segment_names.reserve(s.subgraph_node_ids.size()); - for (int i : s.subgraph_node_ids) { - auto node = s.graph.FindNodeId(i); - segment_names.push_back(node->name()); - } - LOG(INFO) << "finished op preparation"; - - auto status = op_builder.Attr("segment_nodes", segment_names) - .Attr("input_names", input_names) - .Attr("segment_output_names", output_names) - .Attr("resource_name", calib_op_name) - .Finalize(s.trt_node); - - LOG(INFO) << status.ToString(); - LOG(INFO) << "finished op building"; - + VLOG(1) << "Finished conversion"; return tensorflow::Status::OK(); } -tensorflow::Status ConvertSubGraphToTensorRTNodeDef( - tensorrt::convert::SubGraphParams& s) { - // Visit nodes in reverse topological order and construct the TRT network. - std::list order; - TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order)); - - static int static_id = 0; - string subgraph_name_scope = SubgraphNameScopeGenerator(&order); - string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id++); - - tensorflow::tensorrt::Logger trt_logger; - cudaSetDevice(s.cuda_gpu_id_); - auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger)); - if (!trt_builder) { - return tensorflow::errors::Internal( - "Failed to create TensorRT builder object"); - } -#if NV_TENSORRT_MAJOR > 3 - trt_builder->setGpuAllocator(s.allocator_.get()); -#endif - auto trt_network = infer_object(trt_builder->createNetwork()); - if (!trt_network) { - return tensorflow::errors::Internal( - "Failed to create TensorRT network object"); - } - - auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance(); - auto weight_rmgr = trt_rmgr->getManager("WeightStore"); - auto ws = new tensorflow::tensorrt::TRTWeightStore(); - TF_CHECK_OK(weight_rmgr->Create(engine_name, engine_name, ws)); - - // Build the network - Converter converter(trt_network.get(), ws, s.precision_mode == FP16MODE); - - std::vector input_names; - std::vector input_dtypes; - std::vector output_names; - std::vector output_dtypes; - TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names, - &input_dtypes, &output_names, - &output_dtypes, engine_name)); - - VLOG(2) << "Finished output"; - - // Build the engine - trt_builder->setMaxBatchSize(s.max_batch_size); - trt_builder->setMaxWorkspaceSize(s.max_workspace_size_bytes); - VLOG(0) << "Max batch size= " << s.max_batch_size - << " max workspace size= " << s.max_workspace_size_bytes; - if (s.precision_mode == FP16MODE) { - trt_builder->setHalf2Mode(true); - VLOG(0) << "Using FP16 precision mode"; - } - LOG(INFO) << "starting build engine"; - string engine_plan_string; - { - auto trt_engine = - infer_object(trt_builder->buildCudaEngine(*converter.network())); - VLOG(0) << "Built network"; - if (trt_engine.get() == nullptr) { - return tensorflow::errors::Internal("Engine building failure"); +tensorflow::Status ConvertSegmentToGraphDef( + const tensorflow::Graph* graph, + const tensorflow::grappler::GraphProperties& graph_properties, + const std::vector& subgraph_node_ids, // In topological order + std::vector* connections, + tensorflow::GraphDef* segment_def, string* common_scope) { + std::set marker_nodes; + // Update connection shapes/data types and add corresponding input/output + // nodes in the segment graphdef. + for (size_t i = 0; i < connections->size(); ++i) { + auto& connection = connections->at(i); + auto outside_node = graph->FindNodeId(connection.outside_id); + if (!outside_node) { + // This should never happen, unless the original graph is problematic. + return tensorflow::errors::NotFound( + "Cannot find node with id ", connection.outside_id, " in the graph."); + } + // Updates the shape and data types of input/output connections. + tensorflow::DataType input_type = tensorflow::DT_FLOAT; + tensorflow::PartialTensorShape partial_shape; + if (connection.is_input_edge) { + if (graph_properties.HasOutputProperties(connection.outside_node_name)) { + auto output_params = + graph_properties.GetOutputProperties(connection.outside_node_name); + auto out_shape = output_params.at(connection.outside_port); + input_type = out_shape.dtype(); + std::vector dims; + partial_shape = out_shape.shape(); + connection.outside_shape = partial_shape; + } else { + VLOG(0) << "Unknown output shape" << outside_node->name(); + input_type = graph->FindNodeId(connection.outside_id) + ->output_type(connection.outside_port); + } + connection.connection_type = input_type; + + } else { // output edge + if (graph_properties.HasInputProperties(connection.outside_node_name)) { + auto input_params = + graph_properties.GetInputProperties(connection.outside_node_name); + auto in_shape = input_params.at(connection.outside_port); + input_type = in_shape.dtype(); + partial_shape = in_shape.shape(); + connection.inside_shape = partial_shape; + } else { + input_type = graph->FindNodeId(connection.inside_id) + ->output_type(connection.outside_port); + } + connection.connection_type = input_type; } - auto engine_plan = infer_object(trt_engine->serialize()); - VLOG(0) << "Serialized engine"; - const char* engine_plan_data = - static_cast(engine_plan->data()); - engine_plan_string = - string(engine_plan_data, engine_plan_data + engine_plan->size()); - } - TF_RETURN_IF_ERROR(weight_rmgr->Delete( - engine_name, engine_name)); - LOG(INFO) << "finished engine " << engine_name << " containing " - << s.subgraph_node_ids.size() << " nodes"; - - // Build the TRT op - tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); - TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); - - VLOG(0) << "Finished op preparation"; - - auto status = op_builder.Attr("serialized_engine", engine_plan_string) - .Attr("input_nodes", input_names) - .Attr("output_nodes", output_names) - .Attr("OutT", output_dtypes) - .Device(s.device_name_) - .Finalize(s.trt_node); - - VLOG(0) << status.ToString() << " finished op building for " << engine_name - << " on device " << s.device_name_; + // Add dummy input/output nodes to the segment graphdef. + if (connection.is_input_edge) { + const string node_name = StrCat(kInputPHName, connection.port_number); + if (marker_nodes.count(node_name)) { + VLOG(1) << "Reusing input " << node_name << " for the edge " + << connection.outside_node_name << ":" + << connection.outside_port << " -> " + << connection.inside_node_name << ":" << connection.inside_port; + continue; + } + marker_nodes.insert(node_name); + auto seg_node = segment_def->add_node(); + tensorflow::NodeDefBuilder builder(node_name, "Placeholder"); + auto status = builder.Attr("shape", partial_shape) + .Attr("dtype", input_type) + .Finalize(seg_node); + VLOG(1) << "Constructing input " << node_name << " for the edge " + << connection.outside_node_name << ":" << connection.outside_port + << " -> " << connection.inside_node_name << ":" + << connection.inside_port; + } else { + const string node_name = StrCat(kOutputPHName, connection.port_number); + if (marker_nodes.count(node_name)) { + VLOG(1) << "Reusing output " << node_name << " for the edge " + << connection.inside_node_name << ":" << connection.inside_port + << " -> " << connection.outside_node_name << ":" + << connection.outside_port; + continue; + } + marker_nodes.insert(node_name); + auto seg_node = segment_def->add_node(); + tensorflow::NodeDefBuilder builder(node_name, "Identity"); + auto status = builder.Input(connection.inside_node_name, 0, input_type) + .Finalize(seg_node); + VLOG(1) << "Constructing output " << node_name << " for the edge " + << connection.inside_node_name << ":" << connection.inside_port + << " -> " << connection.outside_node_name << ":" + << connection.outside_port; + } + } // for each connection. + + std::unordered_map old_to_new_id_map; + // Copy internal nodes to new graphdef + string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name(); + for (const auto node_id : subgraph_node_ids) { + const auto node = graph->FindNodeId(node_id); + local_scope = GetCommonNameScope(local_scope, node->name()); + old_to_new_id_map[node_id] = segment_def->node_size(); + auto snode = segment_def->add_node(); + snode->CopyFrom(node->def()); + VLOG(1) << "Copying " << snode->name() << " to subgraph"; + } + // Update the inputs of the new input nodes to point to placeholder nodes. + for (int i = 0; i < connections->size(); ++i) { + auto& connection = connections->at(i); + if (!connection.is_input_edge) continue; + auto snode = + segment_def->mutable_node(old_to_new_id_map[connection.inside_id]); + const string placeholder_name = + StrCat(kInputPHName, connection.port_number); + VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port + << " from " << snode->input(connection.inside_port) << " to " + << placeholder_name; + snode->set_input(connection.inside_port, placeholder_name); + } + *common_scope = local_scope; + VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph"; return tensorflow::Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h index 3f6592cd25ff013cadc0621ba64f0553983dd10b..7684d8d4a23ae22c855d82fc54931151a976eb2f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h @@ -22,69 +22,112 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/lib/core/status.h" + #if GOOGLE_CUDA #if GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { +static const char* kInputPHName = "InputPH_"; +static const char* kOutputPHName = "OutputPH_"; namespace convert { +// TODO(aaroey): use an enum instead. const int FP32MODE = 0; const int FP16MODE = 1; const int INT8MODE = 2; -struct SubGraphParams { - SubGraphParams( - tensorflow::Graph& inp_graph, - const std::set& subgraph_node_id_numbers, - const std::vector>& input_indices, - const std::vector>& output_indices, - size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes, - const tensorflow::grappler::GraphProperties& current_graph_properties, - std::unordered_map>* output_edges, - tensorflow::NodeDef* constructed_trt_node, - int engine_precision_mode = FP32MODE, const string& device_name = "", - std::shared_ptr allocator = nullptr, - int cuda_gpu_id = 0) - : graph(inp_graph), - subgraph_node_ids(subgraph_node_id_numbers), - input_inds(input_indices), - output_inds(output_indices), - max_batch_size(max_supported_batch_size), - max_workspace_size_bytes(max_consumed_workspace_size_bytes), - graph_properties(current_graph_properties), - output_edge_map(output_edges), - trt_node(constructed_trt_node), - precision_mode(engine_precision_mode), - device_name_(device_name), - allocator_(allocator), - cuda_gpu_id_(cuda_gpu_id) {} - - tensorflow::Graph& graph; - const std::set& subgraph_node_ids; - const std::vector>& input_inds; // {node_id, output_idx} - const std::vector>& output_inds; // {node_id, output_idx} - size_t max_batch_size; - size_t max_workspace_size_bytes; - const tensorflow::grappler::GraphProperties& graph_properties; - std::unordered_map>* output_edge_map; - tensorflow::NodeDef* trt_node; - const int precision_mode; - const string device_name_; - std::shared_ptr allocator_; - const int cuda_gpu_id_; +struct EngineConnection { + EngineConnection(const string& outside, int out_id, int out_port, + const string& inside, int in_id, int in_port, + bool input_edge, int port) + : outside_node_name(outside), + outside_id(out_id), + outside_port(out_port), + inside_node_name(inside), + inside_id(in_id), + inside_port(in_port), + is_input_edge(input_edge), + port_number(port) {} + + const string outside_node_name; + const int outside_id; + const int outside_port; + tensorflow::PartialTensorShape outside_shape; + + const string inside_node_name; + const int inside_id; + const int inside_port; + tensorflow::PartialTensorShape inside_shape; + + tensorflow::DataType connection_type; + bool is_input_edge; + + // The port number of the TRT node connecting to this edge. + int port_number; +}; + +struct EngineInfo { + EngineInfo() + : engine_type(EngineType::TRTStatic), + max_workspace_size_bytes(0), + precision_mode(FP32MODE) {} + + string engine_name; + string device; + tensorflow::GraphDef segment_graph_def; + + // The segment nodes that are on one side of the edges are topological sorted. + std::vector connections; + + enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; + EngineType engine_type; + int64 max_workspace_size_bytes; + int maximum_cached_engines; + std::vector cached_engine_batches; + int precision_mode; }; -// TODO(sami): Replace references with const reference or pointers -tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params); -tensorflow::Status InjectCalibrationNode(SubGraphParams& params); -tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph, - tensorflow::Node* c_node); +// Constructs a graphdef from the segment in the given graph. Adds placeholder +// nodes for input edges (InputPH_*) and identity nodes for output edges +// (OutputPH_*). This function needs to be called before TensorRT nodes +// inserted in order to correctly get sizes from the original graph. +// +// - subgraph_node_ids: the node ids of the subgraph, must be sorted in +// topological order. +// - segment_def: the output GraphDef, whose non-input/output nodedefs will be +// sorted in topological order. +tensorflow::Status ConvertSegmentToGraphDef( + const tensorflow::Graph* graph, + const tensorflow::grappler::GraphProperties& graph_properties, + const std::vector& subgraph_node_ids, + std::vector* connections, + tensorflow::GraphDef* segment_def, string* common_scope); + +// Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff +// 'builder' successfully build the engine. If the result is not ok, 'engine' +// will be set to nullptr +// Once returned, 'builder' is not needed any more and can be safely detroyed. +// +// - convert_successfully: indicates whether the converson to TensorRT network +// is successful. This is different than successfully building the engine: +// building can still fail afterwards. +tensorflow::Status ConvertGraphDefToEngine( + const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size, + size_t max_workspace_size_bytes, + const std::vector& input_shapes, + Logger* logger, nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator, + TrtUniquePtrType* engine, + bool* convert_successfully); + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index 8f634b1f74717310a69a6bab5d5224c9bdbf10cc..ec9dbfa13bfd0a158dcf41cf1fdb7128a2adf641 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -45,8 +45,24 @@ tensorflow::Status TRTOptimizationPass::Init( if (params.count("max_batch_size")) { maximum_batch_size_ = params.at("max_batch_size").i(); } - if (params.count("max_workspace_size_bytes")) + is_dynamic_op_ = false; + if (params.count("is_dynamic_op")) { + is_dynamic_op_ = params.at("is_dynamic_op").b(); + } + if (params.count("cached_engine_batches")) { + auto batch_vec = params.at("cached_engine_batches").list(); + batches_.reserve(batch_vec.i_size()); + for (const auto i : batch_vec.i()) { + batches_.push_back(i); + } + } + max_cached_batches_ = 1; + if (params.count("maximum_cached_engines")) { + max_cached_batches_ = params.at("maximum_cached_engines").i(); + } + if (params.count("max_workspace_size_bytes")) { maximum_workspace_size_ = params.at("max_workspace_size_bytes").i(); + } if (params.count("precision_mode")) { string pm = Uppercase(params.at("precision_mode").s()); if (pm == "FP32") { @@ -175,6 +191,17 @@ tensorflow::Status TRTOptimizationPass::Optimize( if (VLOG_IS_ON(1)) { PrintDebugInfo(cluster, item); } + // This is a hack to workaround optimizer issue. MetaOptimizer calls + // optimization passes on function objects as well, we should not modify + // generated funcdefs! This is fragile but we don't have any other option + // until framework fixes it. + if (item.id != "tf_graph") { + LOG(WARNING) << name_ + << " is probably called on funcdef! This optimizer must *NOT* " + "be called on function objects."; + *optimized_graph = item.graph; + return tensorflow::Status::OK(); + } int max_dim = -1; if (item.feed.size()) { for (const auto& f : item.feed) { @@ -204,11 +231,22 @@ tensorflow::Status TRTOptimizationPass::Optimize( } tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); - auto status = tensorflow::tensorrt::convert::ConvertAfterShapes( - item.graph, item.fetch, maximum_batch_size_, maximum_workspace_size_, - optimized_graph, precision_mode_, minimum_segment_size_, - static_graph_properties, cluster); + tensorflow::tensorrt::convert::ConversionParams cp; + cp.input_graph_def = &item.graph; + cp.output_names = &item.fetch; + cp.max_batch_size = maximum_batch_size_; + cp.max_workspace_size_bytes = maximum_workspace_size_; + cp.output_graph_def = optimized_graph; + cp.precision_mode = precision_mode_; + cp.minimum_segment_size = minimum_segment_size_; + cp.graph_properties = &static_graph_properties; + cp.cluster = cluster; + cp.is_dyn_op = is_dynamic_op_; + cp.cached_engine_batches = batches_; + cp.max_cached_engines = max_cached_batches_; + auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp); VLOG(2) << optimized_graph->DebugString(); + VLOG(1) << "Returning from " << name_; return status; } diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h index d8ecead23efaa5c3bab95b8ba481e2307b0af772..463ed3883e4808408104c618a289989472c497ea 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h @@ -61,6 +61,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer { int minimum_segment_size_; int precision_mode_; int maximum_batch_size_; + bool is_dynamic_op_; + std::vector batches_; + int max_cached_batches_; int64_t maximum_workspace_size_; }; diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.cc b/tensorflow/contrib/tensorrt/convert/utils.h similarity index 53% rename from tensorflow/compiler/xla/service/versioned_computation_handle.cc rename to tensorflow/contrib/tensorrt/convert/utils.h index a693c4695f0e776cf297d0ecd28d6de53bd5c0c6..f601c06701fdbf983b708cf5f5c7d22634bb810b 100644 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.cc +++ b/tensorflow/contrib/tensorrt/convert/utils.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,20 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ -#include "tensorflow/core/lib/strings/strcat.h" +#include -namespace xla { +namespace tensorflow { +namespace tensorrt { -string VersionedComputationHandle::ToString() const { - return tensorflow::strings::StrCat(handle.handle(), ":v", version); -} +template +struct TrtDestroyer { + void operator()(T* t) { + if (t) t->destroy(); + } +}; -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle) { - out << versioned_handle.ToString(); - return out; -} +template +using TrtUniquePtrType = std::unique_ptr>; -} // namespace xla +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_ diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc deleted file mode 100644 index aea44fd8a2fcc4c359a6cb0c98ae34711708326e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/tensorrt/kernels/trt_calib_op.h" -#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" -#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/stream_executor.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -#include "cuda/include/cuda_runtime_api.h" -#include "tensorrt/include/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { - -TRTCalibOp::TRTCalibOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("segment_nodes", &segment_nodes_)); - OP_REQUIRES_OK(context, context->GetAttr("input_names", &input_names_)); - OP_REQUIRES_OK(context, context->GetAttr("resource_name", &resource_name_)); -}; - -#define TYPECASE(dt, X, Y) \ - case dt: { \ - return (void*)X->flat::Type>().data(); \ - } - -void* GetTensorAddress(const Tensor* tensor_ptr) { - auto tensor_type = tensor_ptr->dtype(); - switch (tensor_type) { - TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); - TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); - TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); - default: { - LOG(FATAL) << "Unsupported Data type " - << tensorflow::DataTypeString(tensor_type); - return nullptr; - } - } -} - -void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) { - // TODO(aaroey): make sure ctx->resource_mgr() is used in future PR. - auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); - auto res_mgr = trt_rm->getManager("TRTCalibOps"); - tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; - auto status = res_mgr->Lookup(resource_name_, resource_name_, &calib_res); - - if (!status.ok()) { - ctx->SetStatus(status); - return; - } - int num_inputs = ctx->num_inputs(); - // first run instantiate calibrator - if (calib_res->calibrator_ == nullptr) { - dev_tensors_.resize(num_inputs); - int batch_size = ctx->input(0).dim_size(0); - VLOG(1) << " Constructing calibrator"; - for (int i = 0; i < num_inputs; i++) { - // allocate workspace on device for inputs - const tensorflow::Tensor& t = ctx->input(i); - OP_REQUIRES_OK(ctx, - ctx->allocate_persistent(t.dtype(), t.shape(), - &dev_tensors_.at(i), nullptr)); - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); - void* device_address = GetTensorAddress(device_tensor); - device_buffers_.emplace(input_names_.at(i), - std::pair( - device_address, device_tensor->TotalBytes())); - } - - calib_res->calibrator_ = - new TRTInt8Calibrator(device_buffers_, batch_size, resource_name_); - string label(resource_name_); - calib_res->thr_ = new std::thread([calib_res, label]() { - VLOG(1) << "Starting calibration thread, Calibration Resource @ " - << calib_res; - calib_res->builder_->setInt8Calibrator(calib_res->calibrator_); - calib_res->builder_->setInt8Mode(true); - calib_res->engine_ = calib_res->builder_->buildCudaEngine( - *calib_res->network_); // will loop until we terminate calibrator - VLOG(1) << "Calibration loop terminated " << label; - }); - VLOG(1) << "initialized calibrator resource"; - } // calibrator initialized - - // Pass input data to calibrator - std::unordered_map input_data; - for (int i = 0; i < num_inputs; i++) { - const Tensor& t = ctx->input(i); - void* data_address = GetTensorAddress(&t); - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), - device_tensor->TotalBytes()); // use the tensor so FW keeps it - input_data.emplace(input_names_.at(i), data_address); - ctx->set_output(i, t); - } - VLOG(2) << "Filled map for sending"; - // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files - const cudaStream_t* stream = CHECK_NOTNULL( - reinterpret_cast(ctx->op_device_context() - ->stream() - ->implementation() - ->CudaStreamMemberHack())); - calib_res->calibrator_->setBatch(input_data, *stream); - VLOG(2) << "Passed calibration data"; - // TODO(aaroey): make sure we wait for the completion of calibration on the - // last batch in future PR. -}; - -#undef TYPECASE - -REGISTER_KERNEL_BUILDER(Name("TRTCalibOp").Device(DEVICE_GPU), TRTCalibOp); - -} // namespace tensorrt -} // namespace tensorflow -#endif -#endif diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h deleted file mode 100644 index 23df9db32f077a080eaff7479fcbe90d6a504c42..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H -#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H - -#include -#include -#include -#include -#include -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/platform/types.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -namespace tensorflow { -namespace tensorrt { -// TODO(sami): Convert this to async kernel! -class TRTCalibOp : public OpKernel { - public: - explicit TRTCalibOp(OpKernelConstruction* context); - - void Compute(OpKernelContext* context) override; - - private: - string resource_name_; - std::vector segment_nodes_; - std::vector input_names_; - std::vector shapes_; - std::unordered_map> device_buffers_; - std::vector dev_tensors_; -}; -} // namespace tensorrt -} // namespace tensorflow -#endif -#endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 9ac8047944874181de228a6cc58e2dafe46abe50..75e32559bb055a49ccef2100c208c6277c0c4b60 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -14,8 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" +#include +#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" @@ -25,144 +33,556 @@ limitations under the License. #include "cuda/include/cuda_runtime_api.h" namespace tensorflow { -static ::tensorflow::tensorrt::Logger logger; -using IRuntime = nvinfer1::IRuntime; -using Dims = nvinfer1::Dims; - namespace tensorrt { +static Logger logger; +using ::nvinfer1::IRuntime; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +// A helper class to call done() when destructed for asynchronous execution. +// Helps simultaneous execution of native and TRT engines. +class AsyncHelper : public tensorflow::core::RefCounted { + public: + AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; } + ~AsyncHelper() override { done_(); } -TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { + private: + tensorflow::AsyncOpKernel::DoneCallback done_; +}; + +#define TYPECASE(dt, X, Y) \ + case dt: { \ + return (void*)X->flat::Type>().data(); \ + } + +void* GetTensorAddress(const Tensor* tensor_ptr) { + auto tensor_type = tensor_ptr->dtype(); + switch (tensor_type) { + TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); + default: { + LOG(ERROR) << "Unsupported Data type " + << tensorflow::DataTypeString(tensor_type); + return nullptr; + } + } +} + +tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) { + VLOG(1) << "Constructing function handle"; + auto lib = ctx->function_library(); + if (lib == nullptr) { + return tensorflow::errors::Internal("Context function library is null"); + } + auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_); + if (fdef == nullptr) { + return tensorflow::errors::Internal("Native FunctionDef ", funcdef_name_, + " can't be found in function library"); + } + tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops; + inst_ops.overlay_lib = nullptr; + inst_ops.state_handle = ""; + inst_ops.target = ctx->device()->name(); + native_func_ = 0; + auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), + inst_ops, &native_func_); + if (!status.ok()) { + LOG(ERROR) << " Instantiating native function " << funcdef_name_ + << " failed!"; + } + return status; +} + +TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) + : AsyncOpKernel(context) { // read serialized_engine OP_REQUIRES_OK(context, - context->GetAttr("serialized_engine", &serialized_engine_)); + context->GetAttr("serialized_segment", &serialized_segment_)); + OP_REQUIRES_OK(context, + context->GetAttr("workspace_size_bytes", &workspace_size_)); + OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_)); + if (!static_engine_) { + if (!segment_graph_.ParseFromString(serialized_segment_)) { + LOG(ERROR) << "Parsing segment graph failed!"; + context->SetStatus(tensorflow::errors::InvalidArgument( + "Failed to parse segment graphdef!")); + return; + } + serialized_segment_.resize(0); + } + VLOG(1) << "Constructing " << name(); + string precision_string; + OP_REQUIRES_OK(context, + context->GetAttr("precision_mode", &precision_string)); + string calibration_data; + OP_REQUIRES_OK(context, + context->GetAttr("calibration_data", &calibration_data)); + OP_REQUIRES_OK(context, + context->GetAttr("segment_funcdef_name", &funcdef_name_)); + if (precision_string == "FP32") { + precision_mode_ = convert::FP32MODE; + } else if (precision_string == "FP16") { + precision_mode_ = convert::FP16MODE; + } else if (precision_string == "INT8") { + precision_mode_ = convert::INT8MODE; + } + calibration_mode_ = + (precision_mode_ == convert::INT8MODE && calibration_data.size() == 0); + if (calibration_data.size()) { + calibrator_.reset(new TRTInt8Calibrator(calibration_data)); + calibration_data.resize(0); + } + native_func_ = tensorflow::kInvalidHandle; + OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", + &max_cached_engines_)); + OP_REQUIRES_OK(context, + context->GetAttr("fixed_input_size", &fixed_input_size_)); + OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches", + &cached_engine_batches_)); + std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end()); + if (VLOG_IS_ON(1)) { + string s("Engine Batches= "); + for (auto i : cached_engine_batches_) { + StrAppend(&s, i, " "); + } + VLOG(1) << s; + } +} - // register input output node name in trt_sub_graph - OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_)); - OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_)); +void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, + AsyncHelper* helper) { + if (!calibration_mode_) { + VLOG(1) << "Executing native engine"; + } + std::vector inputs; + std::vector* outputs = new std::vector(); + if (native_func_ == tensorflow::kInvalidHandle) { + auto status = ConstructFunctionHandle(ctx); + if (!status.ok()) { + LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_; + ctx->SetStatus(status); + return; + } + } + auto lib = ctx->function_library(); + tensorflow::FunctionLibraryRuntime::Options opts; + opts.step_id = ctx->step_id(); + opts.rendezvous = ctx->rendezvous(); + opts.cancellation_manager = ctx->cancellation_manager(); + opts.runner = ctx->runner(); + for (int i = 0; i < ctx->num_inputs(); i++) { + inputs.push_back(ctx->input(i)); + } + helper->Ref(); // Increment count for calculating native graph + VLOG(1) << "Executing native segment " << name(); + lib->Run(opts, native_func_, inputs, outputs, + [ctx, outputs, helper](const tensorflow::Status& s) { + tensorflow::core::ScopedUnref sc(helper); + VLOG(1) << "Native Segment completed"; + if (!s.ok()) { + ctx->SetStatus(s); + return; + } + for (size_t t = 0; t < outputs->size(); ++t) { + ctx->set_output(t, outputs->at(t)); + } + delete outputs; + }); } -void TRTEngineOp::Compute(OpKernelContext* context) { - // TODO(samikama) runtime should be taken from a resourcemanager as well. - // Only engine should be in the op and context and runtime should be taken - // from resourcemanager +void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, + AsyncHelper* helper) { + helper->Ref(); + tensorflow::core::ScopedUnref sc(helper); + // TODO(aaroey): remove the ResourceMgr singleton. + auto trt_rm = TRTResourceManager::instance(); + auto res_mgr = trt_rm->getManager("TRTCalibration"); + TRTCalibrationResource* calib_res = nullptr; + auto status = res_mgr->LookupOrCreate( + funcdef_name_, "Calibrator", &calib_res, + {[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status { + return this->AllocateCalibrationResources(ctx, cr); + }}); + if (!status.ok()) { + ctx->SetStatus(status); + return; + } + int num_inputs = ctx->num_inputs(); + // Pass input data to calibrator + std::unordered_map input_data; + for (int i = 0; i < num_inputs; i++) { + const Tensor& t = ctx->input(i); + void* data_address = GetTensorAddress(&t); + if (data_address == nullptr) { + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Unsupported data type encountered in input ", i)); + return; + } + // Check the allocated buffer is sufficient for input + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + input_data.emplace(StrCat(kInputPHName, i), data_address); + } + VLOG(2) << "Filled map for sending"; + // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast(ctx->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + calib_res->calibrator_->setBatch(input_data, *stream); + VLOG(2) << "Passed calibration data"; + ExecuteNativeSegment(ctx, helper); +} - if (!trt_execution_context_ptr_) { - IRuntime* infer = nvinfer1::createInferRuntime(logger); -#if NV_TENSORRT_MAJOR > 3 - auto device = context->device(); - auto dev_allocator = - device->GetAllocator(tensorflow::AllocatorAttributes()); - if (!dev_allocator) { - LOG(FATAL) << "Can't find device allocator for gpu device " - << device->name(); - } - allocator_ = std::make_shared(dev_allocator); - infer->setGpuAllocator(allocator_.get()); -#endif - trt_engine_ptr_.reset(infer->deserializeCudaEngine( - serialized_engine_.c_str(), serialized_engine_.size(), - PluginFactoryTensorRT::GetInstance())); - trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); - // Runtime is safe to delete after engine creation - infer->destroy(); - serialized_engine_.clear(); +int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) { + int num_batch = ctx->input(0).shape().dim_size(0); + int smallest_engine = 0; + for (const auto i : cached_engine_batches_) { + if (i >= num_batch) { + smallest_engine = i; + break; + } } - int num_binding = context->num_inputs() + context->num_outputs(); - std::vector buffers(num_binding); + // TODO(sami): Need an LRU here + if (smallest_engine == 0) { + if (max_cached_engines_ > cached_engine_batches_.size()) { + smallest_engine = num_batch; + cached_engine_batches_.push_back(num_batch); + VLOG(1) << "Running with batch size " << num_batch; + } else { + string s("Engine buffer is full. buffer limit= "); + StrAppend(&s, max_cached_engines_, ", current entries= "); + for (auto i : cached_engine_batches_) StrAppend(&s, i, ", "); + StrAppend(&s, "Requested batch= ", num_batch); + LOG(ERROR) << s; + ctx->SetStatus(tensorflow::errors::ResourceExhausted( + "Requested batch size is not available and engine cache is full")); + return -1; + } + } + return smallest_engine; +} - size_t binding_index; - int num_batch = 0; - for (int i = 0; i < context->num_inputs(); i++) { - // Grab the input tensor - binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str()); +void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, + tensorflow::AsyncOpKernel::DoneCallback done) { + auto helper = new AsyncHelper(done); + tensorflow::core::ScopedUnref sc(helper); + if (calibration_mode_) { + ExecuteCalibration(ctx, helper); + return; + } + const int smallest_engine = GetEngineBatch(ctx); + if (smallest_engine < 0) return; // GetEngineBatch already set the status. + + const int num_batch = ctx->input(0).shape().dim_size(0); + auto& engine_ctx_pair = GetEngine(smallest_engine, ctx); + auto& trt_engine_ptr = engine_ctx_pair.first; + if (!trt_engine_ptr) { + LOG(WARNING) << "Engine retrieval for batch size " << num_batch + << " failed Running native segment"; + ExecuteNativeSegment(ctx, helper); + return; + } - const Tensor& input_tensor = context->input(i); + const int num_binding = ctx->num_inputs() + ctx->num_outputs(); + std::vector buffers(num_binding); + for (int i = 0; i < ctx->num_inputs(); i++) { + const string inp_name = StrCat(kInputPHName, i); + const size_t binding_index = + trt_engine_ptr->getBindingIndex(inp_name.c_str()); + + const Tensor& input_tensor = ctx->input(i); const TensorShape& input_shape = input_tensor.shape(); - if (i == 0) { - num_batch = input_shape.dim_size(0); - if (num_batch > trt_engine_ptr_->getMaxBatchSize()) { - LOG(FATAL) << "input tensor batch larger than max_batch_size: " - << trt_engine_ptr_->getMaxBatchSize(); - } - } else if (num_batch != input_shape.dim_size(0)) { - LOG(FATAL) << "input data inconsistent batch size"; - break; + if (num_batch != input_shape.dim_size(0)) { + LOG(ERROR) << "input data inconsistent batch size"; + ctx->SetStatus(tensorflow::errors::FailedPrecondition( + "Different batch sizes between input tensors")); + return; } - auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); + auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = (void*)(input_tensor.flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(FATAL) << "half size is not supported yet!"; - break; + LOG(ERROR) << "FP16 inputs are not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "FP16 inputs are not supported!")); + return; case nvinfer1::DataType::kINT8: - LOG(FATAL) << "int8 is not supported yet!"; - break; + LOG(ERROR) << "INT8 inputs are not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "INT8 inputs are not supported!")); + return; default: - LOG(FATAL) << "Unknown data type: " << int(dtype); - break; + LOG(ERROR) << "Unknown TRT data type: " << int(dtype); + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Unknown ouput TRT data type! ", static_cast(dtype))); + return; } } - for (int i = 0; i < static_cast(output_nodes_.size()); i++) { - // This is bad that we have to reallocate output buffer every run. + for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor - binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str()); + const string output_name = StrCat(kOutputPHName, i); + const size_t binding_index = trt_engine_ptr->getBindingIndex( + output_name.c_str()); Tensor* output_tensor = nullptr; TensorShape output_shape; if (binding_index != -1) { - auto dims = trt_engine_ptr_->getBindingDimensions(binding_index); + auto dims = trt_engine_ptr->getBindingDimensions(binding_index); std::vector trt_shape(dims.nbDims + 1); trt_shape[0] = num_batch; for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j]; - OP_REQUIRES_OK(context, - TensorShapeUtils::MakeShape( - trt_shape.data(), trt_shape.size(), &output_shape)); + OP_REQUIRES_OK( + ctx, TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(), + &output_shape)); } else { - LOG(FATAL) << "output node not found, at " << output_nodes_[i]; - break; + LOG(ERROR) << "output node not found, at " << output_name; + ctx->SetStatus(tensorflow::errors::Internal("output ", output_name, + " couldn't be found!")); + return; } - - OP_REQUIRES_OK(context, - context->allocate_output(i, output_shape, &output_tensor)); - auto dtype = trt_engine_ptr_->getBindingDataType(binding_index); + auto status = ctx->allocate_output(i, output_shape, &output_tensor); + if (!status.ok()) { + LOG(ERROR) << "Allocating output failed with " << status; + ctx->SetStatus(status); + return; + } + auto dtype = trt_engine_ptr->getBindingDataType(binding_index); switch (dtype) { case nvinfer1::DataType::kFLOAT: buffers[binding_index] = reinterpret_cast(output_tensor->flat().data()); break; case nvinfer1::DataType::kHALF: - LOG(FATAL) << "half size is not supported yet!"; - break; + LOG(ERROR) << "half size is not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Half outputs are not supported!")); + return; case nvinfer1::DataType::kINT8: - LOG(FATAL) << "int8 is not supported yet!"; - break; + LOG(ERROR) << "int8 is not supported yet!"; + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "INT8 outputs are not supported!")); + return; default: - LOG(FATAL) << "Unknown data type: " << int(dtype); - break; + LOG(ERROR) << "Unknown TRT data type: " << static_cast(dtype); + ctx->SetStatus(tensorflow::errors::InvalidArgument( + "Unsupported output data type! ", int(dtype))); + return; } } // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files const cudaStream_t* stream = CHECK_NOTNULL( - reinterpret_cast(context->op_device_context() + reinterpret_cast(ctx->op_device_context() ->stream() ->implementation() ->CudaStreamMemberHack())); // TODO(jie): trt enqueue does not return error - auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], - *stream, nullptr); - VLOG(2) << "enqueue returns: " << ret; + auto& trt_execution_context_ptr = engine_ctx_pair.second; + auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream, + nullptr); + if (!ret) { + LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name(); + ctx->SetStatus(tensorflow::errors::Internal( + "Failed to enqueue batch for TRT engine: ", name())); + } // sync should be done by TF. } + TRTEngineOp::~TRTEngineOp() { - // Order matters! - trt_execution_context_ptr_.reset(); - trt_engine_ptr_.reset(); + // We need to manually destroy the engine and execution context before + // the allocator is destructed. + for (auto& eng : engine_map_) { + eng.second.first.reset(); + eng.second.second.reset(); + } allocator_.reset(); } + +nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) { + if (allocator_) return allocator_.get(); + auto device = ctx->device(); + auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(ERROR) << "Can't find device allocator for gpu device " + << device->name(); + ctx->SetStatus(tensorflow::errors::Internal( + "Can't get device allocator for device ", device->name())); + return nullptr; + } + allocator_.reset(new TRTDeviceAllocator(alloc)); + return allocator_.get(); +} + +TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, + OpKernelContext* ctx) { + static EngineCtxPair null_pair = { + TrtUniquePtrType(nullptr), + TrtUniquePtrType(nullptr)}; + // TODO(sami): This method needs to be re-written to use resource manager and + // with LRU mechanism option. + tensorflow::mutex_lock lock(engine_mutex_); + + if (static_engine_) { + if (engine_map_.size()) { + if (engine_map_.begin()->first >= batch_size) { + return engine_map_.begin()->second; + } + return null_pair; + } + TrtUniquePtrType infer(nvinfer1::createInferRuntime(logger)); +#if NV_TENSORRT_MAJOR > 3 + auto allocator = GetAllocator(ctx); + if (allocator == nullptr) { + // GetAllocator already set the Status. + return null_pair; + } + infer->setGpuAllocator(allocator); +#endif + TrtUniquePtrType static_engine( + infer->deserializeCudaEngine(serialized_segment_.c_str(), + serialized_segment_.size(), nullptr)); + auto raw_static_engine = static_engine.get(); + const auto max_batch_size = raw_static_engine->getMaxBatchSize(); + engine_map_[max_batch_size] = { + std::move(static_engine), + TrtUniquePtrType( + raw_static_engine->createExecutionContext())}; + // Runtime is safe to delete after engine creation + serialized_segment_.clear(); + if (max_batch_size < batch_size) return null_pair; + return engine_map_.at(max_batch_size); + } // static_engine_ + + // Handle the dynamic engine case. + auto engine_it = engine_map_.find(batch_size); + if (engine_it == engine_map_.end() && + engine_map_.size() < (size_t)max_cached_engines_) { + nvinfer1::IGpuAllocator* allocator = nullptr; +#if NV_TENSORRT_MAJOR > 3 + allocator = GetAllocator(ctx); + if (allocator == nullptr) { + // GetAllocator already set the Status. + return null_pair; + } +#endif + std::vector shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + shapes.emplace_back(ctx->input(i).shape()); + } + TrtUniquePtrType engine; + bool convert_successfully = false; + VLOG(0) << name() << " Constructing a new engine with batch size " + << batch_size; + // Up to this point, calibrator_ can never be empty, since otherwise it + // means calibration_mode_ is true and this path won't get executed. + auto status = convert::ConvertGraphDefToEngine( + segment_graph_, precision_mode_, batch_size, workspace_size_, shapes, + &logger, allocator, calibrator_.get(), &engine, &convert_successfully); + if (!status.ok()) { + if (convert_successfully) { + // This means it fail to build the engine even when the network is built + // successfully, probably due to internal issues. In this case we don't + // retry in the future. + engine_map_[batch_size] = {nullptr, nullptr}; + } + LOG(ERROR) << "Engine creation for batch size " << batch_size + << " failed " << status; + ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!")); + return null_pair; + } + VLOG(1) << "Conversion is done"; + TrtUniquePtrType exec_context( + engine->createExecutionContext()); + engine_map_[batch_size] = {std::move(engine), std::move(exec_context)}; + } + return engine_map_.at(batch_size); +} + +tensorflow::Status TRTEngineOp::AllocateCalibrationResources( + tensorflow::OpKernelContext* ctx, TRTCalibrationResource** cr) { + auto cres = new TRTCalibrationResource(); + *cr = cres; + // Get the allocator. + auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes()); + if (!alloc) { + LOG(WARNING) << "Can't get device allocator will not be able to " + "allocate memory from TensorFlow memory pool"; + cres->allocator_.reset(new TRTCudaAllocator); + } else { + cres->allocator_.reset(new TRTDeviceAllocator(alloc)); + } + // Get the input shapes. + const int batch_size = ctx->input(0).dim_size(0); + const int num_inputs = ctx->num_inputs(); + std::vector shapes; + dev_tensors_.resize(num_inputs); + VLOG(1) << " Constructing calibrator"; + for (int i = 0; i < num_inputs; i++) { + // allocate workspace on device for inputs + const tensorflow::Tensor& t = ctx->input(i); + shapes.emplace_back(t.shape()); + Tensor* device_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor)); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + void* device_address = GetTensorAddress(device_tensor); + if (device_address == nullptr) { + return tensorflow::errors::InvalidArgument( + "Unsupported data type encountered in input ", i); + } + device_buffers_.emplace( + StrCat(kInputPHName, i), + std::pair(device_address, device_tensor->TotalBytes())); + } + cres->calibrator_.reset( + new TRTInt8Calibrator(device_buffers_, batch_size, name())); + const string label(name()); + auto segment_graph = &segment_graph_; + const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id; + if (cuda_gpu_id < 0) { + LOG(ERROR) << "Can't get gpu_device_info from context->device()"; + return tensorflow::errors::InvalidArgument( + "Context->device doesn't contain device info!"); + } + const int64 workspace_size_bytes = workspace_size_; + cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes, + cuda_gpu_id, workspace_size_bytes]() { + VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id + << ", Calibration Resource @ " << cres; + auto err = cudaSetDevice(cuda_gpu_id); + if (err != cudaSuccess) { + // TODO(aaroey): should return error here. + LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id + << " in calibration thread"; + } + // ConvertGraphDefToEngine() will try to build the engine. This thread + // will loop inside buildCudaEngine() consuming the calibration data + // that is set by the TF op, and drive the builder until calibrator returns + // false. Engine is discarded after calibration table is generated + // + // TODO(aaroey): maybe setting the max batch size using the python + // calibration wrapper class. + auto s = convert::ConvertGraphDefToEngine( + *segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(), + workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(), + cres->calibrator_.get(), &cres->engine_, + /*convert_successfully=*/nullptr); + if (!s.ok()) { + LOG(ERROR) << "Calibration failed: " << s; + cres->calibrator_->setDone(); // Ignore further pushes + } + VLOG(1) << "Calibration loop terminated " << label; + })); + VLOG(1) << "initialized calibrator resource"; + return tensorflow::Status::OK(); +} + REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp); } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index e613a71422852e60565ba7554516d7eace6b9cc7..6fe318be6a6bc9f01ce3b52e0430f2090b53002b 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -19,9 +19,14 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/mutex.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -30,32 +35,95 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -class Logger; - +class TRTInt8Calibrator; +class TRTCalibrationResource; +class AsyncHelper; // TODO(Sami): Remove this file? -class TRTEngineOp : public OpKernel { + +// This OP can construct TRTEngine on the fly and if construction of engine +// fails, executes equivalent subgraph as a TensorFlow function. +class TRTEngineOp : public AsyncOpKernel { public: explicit TRTEngineOp(OpKernelConstruction* context); - void Compute(OpKernelContext* context) override; + void ComputeAsync(OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; ~TRTEngineOp(); private: - template - struct Destroyer { - void operator()(T* d) { d->destroy(); } - }; - - template - using destroyed_ptr = std::unique_ptr>; - destroyed_ptr trt_engine_ptr_; + // Execute calibration + void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); + + // Construct a function handle for executing native funcdef graph + Status ConstructFunctionHandle(OpKernelContext* ctx); + + // Execute replaced native segment as function Op. + void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); + + // Allocate necessary resources for calibration + Status AllocateCalibrationResources(OpKernelContext* ctx, + TRTCalibrationResource** cr); + // TODO(samikama): context should go to a resource manager! - destroyed_ptr trt_execution_context_ptr_; + typedef std::pair, + TrtUniquePtrType> + EngineCtxPair; + EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx); + // Return engine batch closest to input batch. + int GetEngineBatch(OpKernelContext* ctx); + + nvinfer1::IGpuAllocator* GetAllocator(OpKernelContext* ctx); + + // map to keep engines and their execution context for given batch size. + std::unordered_map engine_map_; std::vector input_nodes_; std::vector output_nodes_; - std::shared_ptr allocator_; - string serialized_engine_; + + // keep device allocator for TRT. + std::unique_ptr allocator_; + + // serialized protobuf segment or trt engine depending on static_engine_ flag. + string serialized_segment_; + + // Name of the function for TF native execution of the segment. + string funcdef_name_; + + // GraphDef representation of the segment. + GraphDef segment_graph_; + + // Lookup table for temporary staging areas of input tensors for calibration. + std::unordered_map> device_buffers_; + + // Temporary staging areas for calibration inputs. + std::vector dev_tensors_; + + // Engine Precision mode. + int precision_mode_; + + // Whether engine is constructed during the conversion or needs to be + // constructed from protobuf segment. + bool static_engine_; + + // Whether to calibrate INT8 engine. + bool calibration_mode_; + + // Whether non-batch ranks of the inputs are assumed to be fixed or not for + // engine construction. + bool fixed_input_size_; + + // Batches of the cached engines + std::vector cached_engine_batches_; + + // Maximum number of cached engines + int max_cached_engines_; + + int64 workspace_size_; + mutex engine_mutex_; + FunctionLibraryRuntime::Handle native_func_; + + // The finalized calibrator for inference. + std::unique_ptr calibrator_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc deleted file mode 100644 index 4835e5065068ec7a59995eb7f6126b31aecf6704..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -namespace tensorflow { - -REGISTER_OP("TRTCalibOp") - .Attr("segment_nodes: list(string)") // names of the ops in segment - .Attr("segment_output_names: list(string)") // names of the output ops in - // segment - .Attr("input_names: list(string)") // names of the inputs for - // passing into tensorrt - .Attr("resource_name: string") - .Attr("InT: list({int8, float16, float32})") - .Input("in_tensor: InT") - .Output("out_tensor: InT") - .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { - for (int i = 0; i < c->num_inputs(); i++) { - c->set_output(i, c->input(i)); - } - return Status::OK(); - }); - -} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc index 079d73f7bec3f9a9740e455b31a259cec287f849..383635f428812984915a8c46ad3b92cc7b28a5f7 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -28,11 +28,19 @@ extern Status TRTEngineOpShapeInference(InferenceContext* c); } REGISTER_OP("TRTEngineOp") - .Attr("serialized_engine: string") - .Attr("input_nodes: list(string)") - .Attr("output_nodes: list(string)") - .Attr("InT: list({float32})") - .Attr("OutT: list({float32})") + .Attr("serialized_segment: string") + .Attr("input_shapes: list(shape)") + .Attr("output_shapes: list(shape)") + .Attr("segment_funcdef_name: string") + .Attr("InT: list({int8,float16,float32})") + .Attr("OutT: list({int8,float16,float32})") + .Attr("static_engine: bool = true") + .Attr("fixed_input_size: bool = true") + .Attr("cached_engine_batches: list(int) = []") + .Attr("max_cached_engines_count: int = 1") + .Attr("workspace_size_bytes: int") + .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}") + .Attr("calibration_data: string = ''") .Input("in_tensor: InT") .Output("out_tensor: OutT") .SetShapeFn(shape_inference::TRTEngineOpShapeInference); diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py index 338475d90ea55ab2c1bb8df77f27a71a4a36a5dd..79f512dbcf6bd4d84b98cf69630778734566391c 100644 --- a/tensorflow/contrib/tensorrt/python/trt_convert.py +++ b/tensorflow/contrib/tensorrt/python/trt_convert.py @@ -21,6 +21,8 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long import six as _six from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert +from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version +from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 @@ -29,7 +31,9 @@ from tensorflow.python.framework import errors_impl as _impl from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.platform import tf_logging from tensorflow.python.util import compat + # pylint: enable=unused-import,line-too-long @@ -40,7 +44,10 @@ def create_inference_graph(input_graph_def, max_batch_size=1, max_workspace_size_bytes=2 << 20, precision_mode="FP32", - minimum_segment_size=3): + minimum_segment_size=3, + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]): """Python wrapper for the TRT transformation. Args: @@ -51,6 +58,10 @@ def create_inference_graph(input_graph_def, precision_mode: one of 'FP32', 'FP16' and 'INT8' minimum_segment_size: the minimum number of nodes required for a subgraph to be replaced by TRTEngineOp. + is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT + network and engine at run time. + maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. + cached_engine_batches: batch sizes used to pre-create cached engines. Returns: New GraphDef with TRTEngineOps placed in graph replacing subgraphs. @@ -65,6 +76,30 @@ def create_inference_graph(input_graph_def, "It should be one of {}").format( precision_mode, "{'FP32', 'FP16', 'INT8'}")) mode = supported_precision_modes[precision_mode.upper()] + compiled_version = get_linked_tensorrt_version() + loaded_version = get_loaded_tensorrt_version() + version_mismatch = False + if loaded_version[0] < compiled_version[0]: + tf_logging.error( + "TensorRT version mismatch. Tensorflow was compiled against " + + "TensorRT %s but library loaded from environment is TensorRT %s" % + (".".join([str(x) for x in compiled_version]), + ".".join([str(x) for x in loaded_version])) + + ". Please make sure that correct version of TensorRT " + + "is available in the system and added to ldconfig or LD_LIBRARY_PATH" + ) + raise RuntimeError("Incompatible TensorRT library version") + for i in zip(loaded_version, compiled_version): + if i[0] != i[1]: + tf_logging.warn("TensorRT mismatch. Compiled against version " + + "%s, but loaded %s. Things may not work" % + (".".join([str(x) for x in compiled_version]), + ".".join([str(x) for x in loaded_version]))) + version_mismatch = True + break + if not version_mismatch: + tf_logging.info("Running against TensorRT version %s" % ".".join( + [str(x) for x in loaded_version])) def py2bytes(inp): return inp @@ -100,7 +135,9 @@ def create_inference_graph(input_graph_def, # pair or strings where first one is encoded status and the second # one is the transformed graphs protobuf string. out = trt_convert(input_graph_def_str, out_names, max_batch_size, - max_workspace_size_bytes, mode, minimum_segment_size) + max_workspace_size_bytes, mode, minimum_segment_size, + is_dynamic_op, maximum_cached_engines, + cached_engine_batches) status = to_string(out[0]) output_graph_def_string = out[1] del input_graph_def_str # Save some memory @@ -120,11 +157,12 @@ def create_inference_graph(input_graph_def, return output_graph_def -def calib_graph_to_infer_graph(calibration_graph_def): +def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False): """Convert an existing calibration graph to inference graph. Args: calibration_graph_def: the calibration GraphDef object with calibration data + is_dynamic_op: whether to create dynamic static engines from calibration Returns: New GraphDef with TRTEngineOps placed in graph replacing calibration nodes. Raises: @@ -141,9 +179,16 @@ def calib_graph_to_infer_graph(calibration_graph_def): to_string = py2string else: to_string = py3string - + is_calib_graph = False + for n in calibration_graph_def.node: + if n.op == "TRTEngineOp": + is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s + if not is_calib_graph: + tf_logging.error( + "Not a calib graph. Doesn't seem to contain any calibration nodes.") + return None graph_str = calibration_graph_def.SerializeToString() - out = calib_convert(graph_str) + out = calib_convert(graph_str, is_dynamic_op) status = to_string(out[0]) output_graph_def_string = out[1] del graph_str # Save some memory diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc index 0f0508331c13055096714352e83fc360f0ef39b4..9f115990c3a3e6e92093e5f0d82b985af1b25482 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc @@ -50,7 +50,7 @@ TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator) } void TRTDeviceAllocator::free(void* memory) { - VLOG(2) << "Deallocating " << memory; + VLOG(2) << "Deallocating @ " << memory; allocator_->DeallocateRaw(memory); } diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h index a0c2540a7698bc46a65dbd967412351bac2a4dd2..c5d2cec730f4ae97e4c6bcc19897fd9f321122a7 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ #define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_ - #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/core/framework/allocator.h" @@ -52,7 +51,9 @@ class TRTDeviceAllocator : public nvinfer1::IGpuAllocator { // Allocator implementation wrapping TF device allocators. public: TRTDeviceAllocator(tensorflow::Allocator* allocator); - virtual ~TRTDeviceAllocator() {} + virtual ~TRTDeviceAllocator() { + VLOG(1) << "Destroying allocator attached to " << allocator_->Name(); + } void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) override; void free(void* memory) override; diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc index dc7c93f869f5ef7c8eaa2a87eed26cfe69597fdb..32e81858b95d76a2baebb4804a1326fbbb6144c7 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" #include -#include #include #include "tensorflow/core/platform/logging.h" @@ -37,15 +36,22 @@ TRTInt8Calibrator::TRTInt8Calibrator( : batch_size_(batch_size), done_(false), dev_buffers_(dev_buffers), - calib_running_(false), + calib_running_(true), batch_is_set_(false), engine_name_(engine_name) {} +TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) + : batch_size_(0), + done_(false), + calib_running_(false), + batch_is_set_(false), + calibration_table_(calib_data) {} + bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, const cudaStream_t stream) { tensorflow::mutex_lock lock(cond_mtx_); - while ((calib_running_ || batch_is_set_) && - !done_) { // wait while calibration is running + // wait while calibration is running. + while ((calib_running_ || batch_is_set_) && !done_) { cond_.wait(lock); } if (done_) return false; @@ -59,8 +65,6 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map& data, } const auto& d = devptr->second; - // TODO(aaroey): we should not use sync copy on default stream. Make sure - // stream->ThenMemcpy() is used in future PRs. // TODO(sami,aaroey): Need to figure out a way to ensure synchronization // between stream, perhaps using a tensor? auto status = cudaMemcpyAsync(d.first, it.second, d.second, @@ -84,13 +88,11 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, tensorflow::mutex_lock lock(cond_mtx_); calib_running_ = false; cond_.notify_all(); - while ((!batch_is_set_ && !done_)) { // wait until new batch arrives + // wait until new batch arrives + while ((!batch_is_set_ && !done_)) { cond_.wait(lock); - - } - if (done_) { - return false; } + if (done_) return false; for (int i = 0; i < num_bindings; i++) { auto it = dev_buffers_.find(names[i]); @@ -107,7 +109,9 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, } const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { - return nullptr; + if (calibration_table_.empty()) return nullptr; + length = calibration_table_.size(); + return calibration_table_.data(); } void TRTInt8Calibrator::setDone() { @@ -117,7 +121,11 @@ void TRTInt8Calibrator::setDone() { } void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, - std::size_t length) {} + std::size_t length) { + calibration_table_ = string((const char*)ptr, length); + VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr + << " length=" << length; +} TRTInt8Calibrator::~TRTInt8Calibrator() { VLOG(1) << "Destroying calibrator for " << engine_name_; } diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h index d77aa2c5ab184756adaee38f88180b3c128ebe03..994312d7c3c93ba04394b7d9542d261c57c5609b 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -39,29 +39,48 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { TRTInt8Calibrator( const std::unordered_map>& dev_buffers, int batch_size, string engine_name); + + TRTInt8Calibrator(const string& calibration_data); + + ~TRTInt8Calibrator(); + int getBatchSize() const override; + bool getBatch(void* bindings[], const char* names[], int num_bindings) override; + bool setBatch(const std::unordered_map& data, const cudaStream_t stream); + void setDone(); + + // If not null, calibration is skipped. const void* readCalibrationCache(std::size_t& length) override; + void writeCalibrationCache(const void* ptr, std::size_t length) override; - ~TRTInt8Calibrator(); + + const string& getCalibrationTableAsString() { return calibration_table_; } private: const int batch_size_; - tensorflow::mutex cond_mtx_; // mutex for condition_variable - tensorflow::condition_variable cond_; // condition variable to implement - // producer-consumer queue for - // calibration + + // mutex for condition_variable + tensorflow::mutex cond_mtx_; + + // condition variable to implement producer-consumer queue for calibration + tensorflow::condition_variable cond_; + + // Is calibration finished? bool done_; - const std::unordered_map> - dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with - // buffer names + + // Map to keep tensorrt input buffers and sizes keyed with buffer names + const std::unordered_map> dev_buffers_; + bool calib_running_; bool batch_is_set_; + string engine_name_; + string calibration_table_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index e3469124acd4b9f6f4dd81b9998aa60bfe469b35..b7d5ffd6748ba34c6c4ddbfbfbb44edb6bf2aca8 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/contrib/tensorrt/resources/trt_allocator.h" #include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" @@ -34,50 +35,48 @@ limitations under the License. namespace tensorflow { namespace tensorrt { + class TRTCalibrationResource : public tensorflow::ResourceBase { public: - TRTCalibrationResource() - : calibrator_(nullptr), - builder_(nullptr), - network_(nullptr), - engine_(nullptr), - logger_(nullptr), - thr_(nullptr) {} - ~TRTCalibrationResource() { VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + builder_.reset(); + engine_.reset(); + // We need to manually destroy the builder and engine before the allocator + // is destroyed. + allocator_.reset(); } string DebugString() override { std::stringstream oss; - oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl - << " Builder = " << std::hex << builder_ << std::dec << std::endl - << " Network = " << std::hex << network_ << std::dec << std::endl - << " Engine = " << std::hex << engine_ << std::dec << std::endl - << " Logger = " << std::hex << logger_ << std::dec << std::endl - << " Allocator = " << std::hex << allocator_.get() << std::dec - << std::endl - << " Thread = " << std::hex << thr_ << std::dec << std::endl; + using std::dec; + using std::endl; + using std::hex; + oss << " Calibrator = " << hex << calibrator_.get() << dec << endl + << " Builder = " << hex << builder_.get() << dec << endl + << " Engine = " << hex << engine_.get() << dec << endl + << " Logger = " << hex << &logger_ << dec << endl + << " Allocator = " << hex << allocator_.get() << dec << endl + << " Thread = " << hex << thr_.get() << dec << endl; return oss.str(); } - TRTInt8Calibrator* calibrator_; - nvinfer1::IBuilder* builder_; - nvinfer1::INetworkDefinition* network_; - nvinfer1::ICudaEngine* engine_; - std::shared_ptr allocator_; - tensorflow::tensorrt::Logger* logger_; + std::unique_ptr calibrator_; + TrtUniquePtrType builder_; + TrtUniquePtrType engine_; + std::unique_ptr allocator_; + tensorflow::tensorrt::Logger logger_; // TODO(sami): Use threadpool threads! - std::thread* thr_; + std::unique_ptr thr_; }; -class TRTWeightStore : public tensorflow::ResourceBase { +class TRTWeightStore { public: TRTWeightStore() {} virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); } - string DebugString() override { + string DebugString() { std::stringstream oss; size_t len_bytes = 0; for (const auto& v : store_) { diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h index 1568dd915344e6ba982b5a5550cc5386e047ff9f..81b4bfe49fe375d19f4c7811459f38e25d2edea8 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.h +++ b/tensorflow/contrib/tensorrt/segment/segment.h @@ -29,8 +29,9 @@ namespace tensorflow { namespace tensorrt { namespace segment { -// vector of segments, each entry contains a device name and a set of nodes in -// segment +// Vector of segments, each entry contains a set of node names and a device name +// in the segment. +// TODO(aaroey): use node pointer instead of node name. using SegmentNodesVector = std::vector, string>>; struct SegmentOptions { @@ -48,6 +49,8 @@ struct SegmentOptions { // in the vector describes a subgraph by giving a set of the names of // all the NodeDefs in that subgraph. // @return the status. +// +// TODO(aaroey): remove this method. tensorflow::Status SegmentGraph( const tensorflow::GraphDef& gdef, const std::function& candidate_fn, diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index f36495f6b69ecb2f2a8d730b9ae4919fea3c04b8..227ac120dde8c986379c687987cd1bd822d559f7 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -29,61 +29,35 @@ namespace tensorflow { namespace shape_inference { tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { - tensorflow::tensorrt::Logger logger; - string serialized_engine; - TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine)); - nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); - nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), - tensorrt::PluginFactoryTensorRT::GetInstance()); - - int num_batch = -1; - std::vector<::tensorflow::DataType> input_type; - TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type)); - for (size_t i = 0; i < context->num_inputs(); i++) { - // Check if input shape is legit - auto input_shape = context->input(i); - for (int j = 0; j < context->Rank(input_shape); j++) { - auto dim_handler = context->Dim(input_shape, j); - if (j == 0) { - if (i == 0) { - num_batch = context->Value(dim_handler); - } else if (num_batch != context->Value(dim_handler)) { - // TODO(jie): TensorRT engine requires consistent batch between inputs - // tensors. Segmenter should be aware of this. - LOG(FATAL) << "TensorRT engine requires consistent batch size"; - } - } - } + std::vector shapes; + for (int i = 0; i < context->num_outputs(); ++i) { + context->set_output(i, context->UnknownShape()); } - - // Arrange input here - std::vector input_nodes; - TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes)); - - // Arrange output here - std::vector output_nodes; - TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes)); - for (size_t i = 0; i < output_nodes.size(); i++) { - int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str()); - ShapeHandle output_shape; - std::vector dim_vec; - dim_vec.emplace_back(context->MakeDim(num_batch)); - if (binding_index != -1) { - auto dims = trt_engine->getBindingDimensions(binding_index); - for (int j = 0; j < dims.nbDims; j++) { - dim_vec.emplace_back(context->MakeDim(dims.d[j])); - } - } else { - LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i]; - } - output_shape = context->MakeShape(dim_vec); - context->set_output(i, output_shape); + auto status = context->GetAttr("input_shapes", &shapes); + // it is ok to not to have shapes + if (!status.ok()) return Status::OK(); + if ((int)shapes.size() != context->num_inputs()) return Status::OK(); + bool different_input = false; + for (int i = 0; i < context->num_inputs(); ++i) { + if (shapes.at(i) != context->input_tensor(i)->shape()) + different_input = true; + } + if (different_input) return Status::OK(); + shapes.resize(0); + status = context->GetAttr("output_shapes", &shapes); + if (!status.ok()) return Status::OK(); + if ((int)shapes.size() != context->num_outputs()) return Status::OK(); + std::vector shape_handles(shapes.size()); + for (size_t i = 0; i < shapes.size(); ++i) { + status = + context->MakeShapeFromTensorShape(shapes.at(i), &shape_handles.at(i)); + if (!status.ok()) return Status::OK(); + } + for (int i = 0; i < context->num_outputs(); ++i) { + context->set_output(i, shape_handles.at(i)); } - return Status::OK(); } - } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index 175ccd800686255092e241aa59568df407d6eebc..090aa8bdb0487973e186631af3b4edac48096a5f 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -20,6 +20,7 @@ from __future__ import print_function import argparse import numpy as np +import six as _six # normally we should do import tensorflow as tf and then # tf.placeholder, tf.constant, tf.nn.conv2d etc but @@ -35,10 +36,75 @@ from tensorflow.python.framework import dtypes as dtypes from tensorflow.python.framework import importer as importer from tensorflow.python.framework import ops as ops from tensorflow.python.ops import array_ops as aops +from tensorflow.python.ops import math_ops as mops from tensorflow.python.ops import nn as nn from tensorflow.python.ops import nn_ops as nn_ops +def py2bytes(inp): + return inp + + +def py3bytes(inp): + return inp.encode("utf-8", errors="surrogateescape") + + +def py2string(inp): + return inp + + +def py3string(inp): + return inp.decode("utf-8") + + +if _six.PY2: + to_bytes = py2bytes + to_string = py2string +else: + to_bytes = py3bytes + to_string = py3string + + +def get_multi_engine_graph_def(mode="FP32"): + """Create a simple graph and return its graph_def.""" + dtype = dtypes.float32 + if mode.upper() == "FP16": + dtype = dtypes.float16 + else: + pass + + g = ops.Graph() + with g.as_default(): + x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype) + with g.name_scope("Global_scope"): + with g.name_scope("first_scope"): + e = cop.constant( + np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype) + conv = nn.conv2d( + input=x, + filter=e, + data_format="NCHW", + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype) + t = conv * b + + b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype) + q = conv / b + edge = mops.sin(q) + edge1 = mops.cos(conv) + with g.name_scope("test_scope"): + de = edge + edge1 + t -= edge1 + q *= edge + t += q + t -= de + k = aops.squeeze(t, name="output") + print(k.dtype) + return g.as_graph_def() + + def get_simple_graph_def(): """Create a simple graph and return its graph_def.""" g = ops.Graph() @@ -65,7 +131,9 @@ def get_simple_graph_def(): def execute_graph(gdef, dumm_inp): """Run given graphdef once.""" print("executing") - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) sessconfig = cpb2.ConfigProto(gpu_options=gpu_options) ops.reset_default_graph() g = ops.Graph() @@ -83,7 +151,9 @@ def execute_graph(gdef, dumm_inp): # for calibration. For this test script it is random data. def execute_calibration(gdef, dumm_inp): """Run given calibration graph multiple times.""" - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) ops.reset_default_graph() g = ops.Graph() with g.as_default(): @@ -100,12 +170,17 @@ def execute_calibration(gdef, dumm_inp): return val -def user(run_graph=execute_graph, run_calibration=execute_calibration): +def user(multi_engine, + run_graph=execute_graph, + run_calibration=execute_calibration): """Example function that converts a graph to TFTRT graph.""" - - inp_dims = (100, 24, 24, 2) + if multi_engine: + inp_dims = (2, 3, 7, 5) + orig_graph = get_multi_engine_graph_def() + else: + inp_dims = (100, 24, 24, 2) + orig_graph = get_simple_graph_def() # use a frozen graph for inference dummy_input = np.random.random_sample(inp_dims) - orig_graph = get_simple_graph_def() # use a frozen graph for inference # Get optimized graph trt_graph = trt.create_inference_graph( input_graph_def=orig_graph, @@ -113,8 +188,10 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration): max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) o1 = run_graph(orig_graph, dummy_input) o2 = run_graph(trt_graph, dummy_input) o3 = run_graph(trt_graph, dummy_input) @@ -126,40 +203,51 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration): max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) int8_calib_gdef = trt.create_inference_graph( input_graph_def=orig_graph, outputs=["output"], max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8" - minimum_segment_size=2 # minimum number of nodes in an engine - ) + minimum_segment_size=2, # minimum number of nodes in an engine + is_dynamic_op=False, + maximum_cached_engines=1, + cached_engine_batches=[]) o4 = run_graph(fp16_graph, dummy_input) _ = run_calibration(int8_calib_gdef, dummy_input) int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) o5 = run_graph(int8_graph, dummy_input) - assert np.allclose(o1, o4) - assert np.allclose(o1, o5) + print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4)) + print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5)) print("Pass") -def auto(): +def auto(multi_engine): """Run the conversion as an optimization pass.""" - inp_dims = (100, 24, 24, 2) + if multi_engine: + inp_dims = (2, 3, 7, 5) + orig_graph = get_multi_engine_graph_def() + else: + inp_dims = (100, 24, 24, 2) + orig_graph = get_simple_graph_def() # use a frozen graph for inference dummy_input = np.random.random_sample(inp_dims) - orig_graph = get_simple_graph_def() opt_config = rwpb2.RewriterConfig() + opt_config.meta_optimizer_iterations = opt_config.ONE opt_config.optimizers.extend(["constfold", "layout"]) custom_op = opt_config.custom_optimizers.add() custom_op.name = "TensorRTOptimizer" custom_op.parameter_map["minimum_segment_size"].i = 3 - custom_op.parameter_map["precision_mode"].s = "FP32" + custom_op.parameter_map["precision_mode"].s = to_bytes("FP32") custom_op.parameter_map["max_batch_size"].i = inp_dims[0] custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25 print(custom_op) - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = None + if trt.trt_convert.get_linked_tensorrt_version()[0] == 3: + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) graph_options = cpb2.GraphOptions(rewrite_options=opt_config) sessconfig = cpb2.ConfigProto( gpu_options=gpu_options, graph_options=graph_options) @@ -168,7 +256,7 @@ def auto(): ops.reset_default_graph() with g.as_default(): inp, out = importer.import_graph_def( - graph_def=orig_graph, return_elements=["input", "output"]) + graph_def=orig_graph, return_elements=["input", "output"], name="") inp = inp.outputs[0] out = out.outputs[0] with csess.Session(config=sessconfig, graph=g) as sess: @@ -186,8 +274,14 @@ if "__main__" in __name__: action="store_true", help="Do TRT conversion automatically", default=False) + P.add_argument( + "--multi-engine", + "-m", + action="store_true", + help="Use a graph that will result in 2 engines", + default=False) flags, unparsed = P.parse_known_args() if flags.automatic: - auto() + auto(flags.multi_engine) else: - user() + user(flags.multi_engine) diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index 46480e99a113afb34702b0ecd71468d4bdc83f98..d51a0b59e22cb063b380808f5887538e0294daff 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -48,12 +48,53 @@ PyObject* pair_helper(std::pair* in) { } return tuple; } + +struct version_struct{ + int vmajor; + int vminor; + int vpatch; +}; + +PyObject* version_helper(version_struct* in) { + PyObject *tuple(nullptr); + tuple = Py_BuildValue("(iii)", in->vmajor, in->vminor, in->vpatch); + if (!tuple) { + if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, + "Tuple creation from version structure failed!"); + } + return NULL; + } + return tuple; +} +/* Define converters for vector */ +template<> +bool _PyObjAs(PyObject *pyobj, int* dest) { + *dest = PyLong_AsLong(pyobj); + return true; +} + +template<> +PyObject *_PyObjFrom(const int& src) { + return PyLong_FromLong(src); +} + %} + +_LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong); + %typemap(out) std::pair { PyObject *tuple = pair_helper(&$1); if (!tuple) SWIG_fail; $result = tuple; } + +%typemap(out) version_struct { + PyObject *tuple = version_helper(&$1); + if (!tuple) SWIG_fail; + $result = tuple; +} + %{ #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -65,6 +106,8 @@ PyObject* pair_helper(std::pair* in) { %unignore tensorflow; %unignore trt_convert; %unignore calib_convert; +%unignore get_linked_tensorrt_version; +%unignore get_loaded_tensorrt_version; %{ @@ -74,7 +117,10 @@ std::pair trt_convert( size_t max_batch_size, size_t max_workspace_size_bytes, int precision_mode, - int minimum_segment_size + int minimum_segment_size, + bool is_dyn_op, + int max_cached_engines, + std::vector cached_engine_batches // Unfortunately we can't use TF_Status here since it // is in c/c_api and brings in a lot of other libraries // which in turn declare ops. These ops are included @@ -102,11 +148,12 @@ std::pair trt_convert( out_status = "InvalidArgument;Size of the output_names vector is 0"; return std::pair{out_status, ""}; } - tensorflow::GraphDef outGraph; + tensorflow::GraphDef out_graph; tensorflow::Status conversion_status = tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT( graph_def, output_names, max_batch_size, max_workspace_size_bytes, - &outGraph, precision_mode, minimum_segment_size); + &out_graph, precision_mode, minimum_segment_size, + is_dyn_op, max_cached_engines, cached_engine_batches); if (!conversion_status.ok()) { auto retCode = (int)conversion_status.code(); char buff[2000]; @@ -116,7 +163,7 @@ std::pair trt_convert( return std::pair{out_status, ""}; } string result; - if (!outGraph.SerializeToString(&result)) { + if (!out_graph.SerializeToString(&result)) { out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; return std::pair{out_status, ""}; } @@ -128,7 +175,8 @@ std::pair trt_convert( #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } -std::pair calib_convert(string graph_def_string // const tensorflow::GraphDef& +std::pair calib_convert( + string graph_def_string, bool is_dyn_op // unfortunately we can't use TF_Status here since it // is in c/c_api and brings in a lot of other libraries // which in turn declare ops. These ops are included @@ -147,11 +195,11 @@ std::pair calib_convert(string graph_def_string // const tenso out_status = "InvalidArgument;Couldn't interpret input as a GraphDef"; return std::pair{out_status, ""}; } - - tensorflow::GraphDef outGraph; + graph_def_string.resize(0); + tensorflow::GraphDef out_graph; tensorflow::Status conversion_status = - tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def, - &outGraph); + tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph( + graph_def, &out_graph, is_dyn_op); if (!conversion_status.ok()) { auto retCode = (int)conversion_status.code(); char buff[2000]; @@ -161,7 +209,7 @@ std::pair calib_convert(string graph_def_string // const tenso return std::pair{out_status, ""}; } string result; - if (!outGraph.SerializeToString(&result)) { + if (!out_graph.SerializeToString(&result)) { out_status = "InvalidArgument;Couldn't serialize output as a GraphDef"; return std::pair{out_status, ""}; } @@ -172,15 +220,39 @@ std::pair calib_convert(string graph_def_string // const tenso return std::pair{"9;TensorRT is not enabled!", ""}; #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } + +version_struct get_linked_tensorrt_version(){ + // Return the version at the link time. + const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion(); + version_struct s; + s.vmajor = lv[0]; + s.vminor = lv[1]; + s.vpatch = lv[2]; + return s; +} +version_struct get_loaded_tensorrt_version(){ + // Return the version from the loaded library. + const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion(); + version_struct s; + s.vmajor = lv[0]; + s.vminor = lv[1]; + s.vpatch = lv[2]; + return s; +} + %} -std::pair calib_convert(string graph_def_string); +std::pair calib_convert(string graph_def_string, bool is_dyn_op); std::pair trt_convert(string graph_def_string, std::vector output_names, size_t max_batch_size, size_t max_workspace_size_bytes, - int precision_mode, int minimum_segment_size); - + int precision_mode, int minimum_segment_size, + bool is_dyn_op, + int max_cached_engines, + std::vector cached_engine_batches); +version_struct get_linked_tensorrt_version(); +version_struct get_loaded_tensorrt_version(); %unignoreall diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index a28a5872b850b51630240bdeb3ff22f372613523..f236329fdb038ba5ab432c6b97f44bda7ccfe815 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -132,7 +132,8 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce loss=model_outputs.loss, mode=mode, eval_metric_ops=metrics, - predictions={}) + # needed for custom metrics. + predictions=model_outputs.predictions) def _predict_ops(self, features): """Add ops for prediction to the graph.""" @@ -210,12 +211,12 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce def create_estimator_spec(self, features, mode, labels=None): """Performs basic error checking and returns an EstimatorSpec.""" with ops.name_scope(self._name, "head"): - if labels: + if labels is not None and labels != {}: # for better error messages. raise ValueError( - "The model received a `labels` dictionary, which is " - "not supported. Pass '{}' and '{}' as " - "features.".format(feature_keys.TrainEvalFeatures.TIMES, - feature_keys.TrainEvalFeatures.VALUES)) + "The model received a `labels`, which is not supported. " + "Pass '{}' and '{}' as features.".format( + feature_keys.TrainEvalFeatures.TIMES, + feature_keys.TrainEvalFeatures.VALUES)) del labels features = { name: self._convert_feature_to_tensor(name=name, value=value) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index c606db76a668235ab6a837159b9dec072b5fd801..ed8f29c321719e552c25f4d2183fdf4eb282e4b7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy import six +from tensorflow.contrib.estimator.python.estimator import extenders from tensorflow.contrib.timeseries.examples import lstm as lstm_example from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators from tensorflow.contrib.timeseries.python.timeseries import feature_keys @@ -35,6 +36,7 @@ from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import variables @@ -53,9 +55,12 @@ class HeadTest(test.TestCase): model_fn = _stub_model_fn() for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL, estimator_lib.ModeKeys.PREDICT]: - with self.assertRaisesRegexp(ValueError, "labels"): + with self.assertRaisesRegexp(ValueError, "received a `labels`"): model_fn(features={}, labels={"a": "b"}, mode=mode) + with self.assertRaisesRegexp(ValueError, "received a `labels`"): + model_fn(features={}, labels=array_ops.zeros([]), mode=mode) + def test_unknown_mode(self): model_fn = _stub_model_fn() with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"): @@ -128,6 +133,44 @@ class EvaluationMetricsTests(test.TestCase): coordinator.request_stop() coordinator.join() + def test_custom_metrics(self): + """Tests that the custom metrics can be applied to the estimator.""" + model_dir = self.get_temp_dir() + estimator = ts_estimators.TimeSeriesRegressor( + model=lstm_example._LSTMModel(num_features=1, num_units=4), + optimizer=adam.AdamOptimizer(0.001), + config=estimator_lib.RunConfig(tf_random_seed=4), + model_dir=model_dir) + + def input_fn(): + return { + feature_keys.TrainEvalFeatures.TIMES: [[1, 2, 3], [7, 8, 9]], + feature_keys.TrainEvalFeatures.VALUES: + numpy.array([[[0.], [1.], [0.]], [[2.], [3.], [2.]]]) + } + + def metrics_fn(predictions, features): + # checking that the inputs are properly passed. + predict = predictions["mean"] + target = features[feature_keys.TrainEvalFeatures.VALUES][:, -1, 0] + return { + "plain_boring_metric386": + (math_ops.reduce_mean(math_ops.abs(predict - target)), + control_flow_ops.no_op()), + "fun_metric101": (math_ops.reduce_sum(predict + target), + control_flow_ops.no_op()), + } + + # Evaluation without training is enough for testing custom metrics. + estimator = extenders.add_metrics(estimator, metrics_fn) + evaluation = estimator.evaluate(input_fn, steps=1) + self.assertIn("plain_boring_metric386", evaluation) + self.assertIn("fun_metric101", evaluation) + # The values are deterministic because of fixed tf_random_seed. + # However if they become flaky, remove such exacts comparisons. + self.assertAllClose(evaluation["plain_boring_metric386"], 1.130380) + self.assertAllClose(evaluation["fun_metric101"], 10.435442) + class _StubModel(object): num_features = 3 diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index f84ff1bfe9b014733205a8e51b43f79c63b227cb..16696793bc2dab977a3dbbfa338e33e5771d0699 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -181,6 +181,7 @@ py_library( ":datasets", ":profiler", ":tpu_py", + "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py", "//tensorflow/contrib/tpu/proto:compilation_result_proto_py", "//tensorflow/contrib/tpu/proto:topology_proto_py", "//tensorflow/core:protos_all_py", diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc index d389050e67f9a9e48b91583e5088058ec4e2832f..06553929dc44ca1f75ce64532a4dcdf1c8aae3eb 100644 --- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc +++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc @@ -23,15 +23,23 @@ REGISTER_OP("CrossReplicaSum") .Input("input: T") .Output("output: T") .Attr("T: {bfloat16, float}") + .Attr("group_assignment: list(int) = []") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( An Op to sum inputs across replicated TPU instances. Each -instance supplies its own input, and the output of each is the sum of -all the inputs. +instance supplies its own input. If group_assignment is empty, the output of +each is the sum of all the inputs, otherwise the output of each is the sum of +the inputs belonging to the same group. + +For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing +group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1. +Thus we get the outputs: `[A+C, B+D, A+C, B+D]`. input: The local input to the sum. output: The sum of all the distributed inputs. T: The type of elements to be summed. +group_assignment: The list of group ids. `group_assignment[i]` represents the + group id of replica i. )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index ab2a7a0d4bec48d6b3b459bb3144e8ddae614ca0..15a2bb17a93212afe9ce5604a28d9dba5825f7d4 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -44,6 +44,27 @@ REGISTER_OP("TPUReplicatedInput") " with other shapes."); } c->set_output(0, cur); + + // If this is a resource, unify the resource shapes. + DataType dtype; + TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype)); + if (dtype == DT_RESOURCE) { + const std::vector* shapes_and_types = + nullptr; + for (int i = c->num_inputs() - 1; i >= 0; --i) { + if (shapes_and_types) { + // The return value of MergeInputHandleShapesAndTypes indicates + // the shape was refined, not that there was an error. + // TODO(phawkins): there seems to be no way to discover errors. + (void)c->MergeInputHandleShapesAndTypes(i, *shapes_and_types); + } else { + shapes_and_types = c->input_handle_shapes_and_types(i); + } + } + if (shapes_and_types) { + c->set_output_handle_shapes_and_types(0, *shapes_and_types); + } + } return Status::OK(); }) .Doc( diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index dbf1ab6bbf0ddc7429d8e19279451eb862981e0c..38d1c3049ef7185f2f9f448361029d066678cdae 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -49,11 +49,11 @@ tf_cc_binary( ":tpu_profiler_analysis_proto_cc", ":tpu_profiler_proto_cc", ":version", + "//tensorflow:grpc++", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/platform/cloud:gcs_file_system", - "@grpc//:grpc++_unsecure", ], ) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 99485322c6b9434f4c1700b9e2a6af00a65f794f..f80f5652af79d410946971573ae160fdd0b85f6d 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -18,7 +18,7 @@ limitations under the License. // Initiates a TPU profiling on the TPUProfiler service at service_addr, // receives and dumps the profile data to a tensorboard log directory. -#include "grpc++/grpc++.h" +#include "grpcpp/grpcpp.h" #include #include diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 508c7a842fb82ec080082d7e7f02f8d2f2a79447..7f1d25732e21b5dea4e605f6caa141ca9d3d02c6 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -35,19 +35,19 @@ flags.DEFINE_string( None, help='GCE zone where the Cloud TPU is located in. If not specified, we ' 'will attempt to automatically detect the GCE project from metadata.') -flags.DEFINE_string('tpu_name', None, +flags.DEFINE_string('tpu', None, 'Name of the Cloud TPU for Cluster Resolvers. You must ' 'specify either this flag or --service_addr.') # Tool specific parameters flags.DEFINE_string( 'service_addr', None, 'Address of TPU profiler service e.g. ' - 'localhost:8466, you must specify either this flag or --tpu_name.') + 'localhost:8466, you must specify either this flag or --tpu.') flags.DEFINE_string( 'workers_list', None, 'The list of worker TPUs that we are about to profile' - ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu_name or ' + ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu or ' '--service_addr to profile a subset of tpu nodes. You can also use only' - '--tpu_name and leave this flag unspecified to profile all the tpus.') + '--tpu and leave this flag unspecified to profile all the tpus.') flags.DEFINE_string('logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' 'gs://tb_bucket') @@ -76,19 +76,19 @@ def run_main(): def main(unused_argv=None): tf.logging.set_verbosity(tf.logging.INFO) - if FLAGS.service_addr is None and FLAGS.tpu_name is None: - sys.exit('You must specify either --service_addr or --tpu_name.') + if FLAGS.service_addr is None and FLAGS.tpu is None: + sys.exit('You must specify either --service_addr or --tpu.') tpu_cluster_resolver = None if FLAGS.service_addr is not None: - if FLAGS.tpu_name is not None: - tf.logging.warn('Both --service_addr and --tpu_name are set. Ignoring ' - '--tpu_name and using --service_addr.') + if FLAGS.tpu is not None: + tf.logging.warn('Both --service_addr and --tpu are set. Ignoring ' + '--tpu and using --service_addr.') service_addr = FLAGS.service_addr else: tpu_cluster_resolver = ( tf.contrib.cluster_resolver.TPUClusterResolver( - [FLAGS.tpu_name], + [FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)) service_addr = tpu_cluster_resolver.get_master() diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index ebd478fd02295108b9d2454963eb06165828b523..f97a972f01a3ba5582df3675439aa962886f796e 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -20,7 +20,7 @@ from __future__ import print_function from setuptools import setup -_VERSION = '1.6.0' +_VERSION = '1.7.0' CONSOLE_SCRIPTS = [ 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', @@ -46,7 +46,7 @@ setup( # 3 - Alpha # 4 - Beta # 5 - Production/Stable - 'Development Status :: 4 - Beta', + 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h index 618479e1a6ccf26a4103ea1f182b662d7d9998da..bd9ba6697edd9ef14dd3af0d2c9b77df9ec6917a 100644 --- a/tensorflow/contrib/tpu/profiler/version.h +++ b/tensorflow/contrib/tpu/profiler/version.h @@ -16,6 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ #define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ -#define TPU_PROFILER_VERSION "1.6.0" +#define TPU_PROFILER_VERSION "1.7.0" #endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_ diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 14c63a79763300dcfe8d6c8e09b90f8e9c772358..bf442d9116d2ceca499ffc66258c64b5b94dd881 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -38,9 +38,8 @@ if platform.system() != "Windows": @ops.RegisterGradient("CrossReplicaSum") def _cross_replica_sum_grad(op, grad): - del op # Unused # The gradient of a cross replica sum is also a cross-replica sum. - return gen_tpu_ops.cross_replica_sum(grad) + return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment")) # This extra type checking exists to give a more helpful error message in # the common case that uint8 and int64 values are infed. Remove when both diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index f1a11fa6548b87d6222a97c72b8db5442c8ef774..293e162059205cad572f0ca78217985b6932a239 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -51,6 +51,7 @@ import collections import re import time +from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result from tensorflow.contrib.tpu.python.ops import tpu_ops @@ -368,10 +369,27 @@ class TPUFunction(object): @experimental -def setup_tpu_session(master): - """Initializes and returns a Keras/TF session connected the TPU `master`.""" +def setup_tpu_session(tpu_name_or_address): + """Initializes and returns a Keras/TF session connected the TPU `master`. + + Args: + tpu_name_or_address: A string that is either the name of the Cloud TPU, + the grpc address of the Cloud TPU, or (Googlers only) the BNS name of the + Cloud TPU. If tpu_name_or_address is None, the TPUClusterResolver will + examine the environment to determine a potential Cloud TPU to use. + + Returns: + A `tf.Session`. + """ + cluster_resolver = tpu_cluster_resolver.TPUClusterResolver( + tpu_name_or_address) + cluster_spec = cluster_resolver.cluster_spec() session = tf_session.Session( - target=master, config=config_pb2.ConfigProto(isolate_session_state=True)) + target=cluster_resolver.master(), + config=config_pb2.ConfigProto( + isolate_session_state=True)) + if cluster_spec: + session.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) K.set_session(session) K.get_session().run(tpu.initialize_system()) return session diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 71a50126910568ab05baeb0bedf4b78d40ff3bf3..dc473c5846aafc5a92756dfb8259f7f8dc14b98d 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -591,16 +591,22 @@ def split_compile_and_replicate(computation, with tpu_function.tpu_shard_context( num_replicas), ops.control_dependencies([metadata]): - # The EncapsulateTPUComputations rewrite needs to identify the - # replicated arguments inside each computation. Adds identity operators - # tagged with an attribute _tpu_replicated_input to identify the - # replicated inputs. + # For backward compatibility reasons, we tag replicated inputs with the + # _tpu_replicated_input attribute. This does nothing and exists only for + # backward compatibility. + # TODO(phawkins): delete the attr_scope after 6/28/2018. # pylint: disable=protected-access - with graph._attr_scope({"_tpu_replicated_input": - attr_value_pb2.AttrValue(b=True)}): + with graph._attr_scope({ + "_tpu_replicated_input": attr_value_pb2.AttrValue(b=True) + }): + # Add identity ops so even unused inputs are "consumed" by the + # computation. This is to avoid orphaned TPUReplicatedInput nodes. + # TODO(phawkins): consider instead pruning unused TPUReplicatedInput + # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs. computation_inputs = [ array_ops.identity(x, name="replicated_input_{}".format(i)) - for i, x in enumerate(computation_inputs)] + for i, x in enumerate(computation_inputs) + ] # pylint: enable=protected-access # If there is an infeed queue, adds the dequeued values to the @@ -623,10 +629,16 @@ def split_compile_and_replicate(computation, vscope.set_use_resource(saved_use_resource) + # If the computation returns `None`, make it an empty tuple. + if outputs is None: + outputs = tuple() # If the computation only returned one value, makes it a tuple. if not isinstance(outputs, (list, tuple)): outputs = (outputs,) + # Append `no_op` here so that fetching any return value of this function + # will trigger TPUExecute node. + outputs += (control_flow_ops.no_op(),) try: with ops.device(core(0)): outputs = [ diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 5b9aeaa8797b92b4cc596744812f440607054dce..c4c69902f95e73e90832b3fd6538d73e474e330a 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -384,9 +384,7 @@ class _InternalTPUContext(object): # On TPU if self.is_input_sharded_per_core() or ( self.is_input_per_host_with_iterators()): - # We prohibit per core input sharding for the model parallelism case, - # therefore it is safe to use num_cores here. - return global_batch_size // self.num_cores + return global_batch_size // self.num_replicas else: return global_batch_size // self.num_hosts @@ -484,25 +482,27 @@ class _InternalTPUContext(object): return _placement_function - @property - def tpu_ordinal_function(self): + def tpu_ordinal_function(self, host_id): """Returns the TPU ordinal fn.""" - def _tpu_ordinal_function(index): + def _tpu_ordinal_function(shard_index_in_host): """Return the TPU ordinal associated with a shard. Required because the enqueue ops are placed on CPU. Args: - index: the shard index + shard_index_in_host: the shard index Returns: The ordinal of the TPU device the shard's infeed should be placed on. """ if self.model_parallelism_enabled: - return self.device_assignment.tpu_ordinal(replica=index) + # We put both enqueue/dequeue ops at tpu.core(0) in each replica. + replica = self.device_assignment.lookup_replicas( + host_id, (0, 0, 0))[shard_index_in_host] + return self.device_assignment.tpu_ordinal(replica=replica) else: - return index % self.num_of_cores_per_host + return shard_index_in_host % self.num_of_cores_per_host return _tpu_ordinal_function diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index f63e9e8bdab90cf3e174ba5813586e9ac1ac7dc0..85ea4d3df3beaff9f21b54af8a4749dc98c4b738 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -122,6 +122,33 @@ def _create_global_step(graph): def _create_or_get_iterations_per_loop(): + """Creates or gets the iterations_per_loop variable. + + In TPUEstimator, the user provided computation, the model_fn, is wrapped + inside a tf.while_loop for peak performance. The iterations of the loop are + specified by this variable, which adjusts its value on the CPU after each TPU + program execution and before the next TPU execution. + + The purpose of using a variable, rather then a constant, is to allow + TPUEstimator adapt the TPU training iterations according to the final steps + specified by users. For example, if the user sets the iterations_per_loop as 4 + in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop + variable will have the following value before each TPU training. + + - 1-th TPU execution: iterations_per_loop = 4 + - 2-th TPU execution: iterations_per_loop = 4 + - 3-th TPU execution: iterations_per_loop = 2 + + As model_fn increases the global step once per train_op invocation, the global + step is 10 after all TPU executions, matching the steps=10 inputs passed in by + users. + + Returns: + A TF non-trainable resource variable. + + Raises: + RuntimeError: If multi iterations_per_loop variables were found. + """ graph = ops.get_default_graph() collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) iter_vars = graph.get_collection(collection_name) @@ -388,20 +415,21 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): return def _cancel_session(): - # Close the session to avoid the main thread from hanging. If input - # pipeline triggers any error, the infeed thread dies but the main thread - # for TPU computation waits for the infeed enqueue forever. Close the - # Session to cancel the main thread Session.run execution. - # - # We sleep for a few seconds before closing to give some time - # for the TPU compilation error, if any, propagating, from TPU to CPU - # host. Compilation errors should be reported by the main thread so that - # the program can be interrupted and users can take action. Due to a race - # condition, the infeed thread might see an error first. Closing the - # session here immediately would result in a session cancellation - # exception in the main thread, instead of the expected compile error. - # User code that depends on having the proper exception type will - # therefore be confused. + """Close the session to avoid the main thread from hanging. + + If input pipeline triggers any error, the infeed thread dies but the main + thread for TPU computation waits for the infeed enqueue forever. Close the + Session to cancel the main thread Session.run execution. + + We sleep for a few seconds before closing to give some time for the TPU + compilation error, if any, propagating, from TPU to CPU host. Compilation + errors should be reported by the main thread so that the program can be + interrupted and users can take action. Due to a race condition, the + infeed thread might see an error first. Closing the session here + immediately would result in a session cancellation exception in the main + thread, instead of the expected compile error. User code that depends on + having the proper exception type will therefore be confused. + """ time.sleep(5) # If the main session is still running, the infeed/outfeed errors are @@ -636,6 +664,7 @@ def generate_per_core_enqueue_ops_fn_for_host( ctx, input_fn, inputs_structure_recorder, host_device, host_id): """Generates infeed enqueue ops for per-core input_fn on a single host.""" captured_infeed_queue = _CapturedObject() + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """A fn returns enqueue_ops.""" @@ -671,7 +700,7 @@ def generate_per_core_enqueue_ops_fn_for_host( per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function) + per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) return per_host_enqueue_ops return enqueue_ops_fn, captured_infeed_queue @@ -706,21 +735,18 @@ def generate_per_host_enqueue_ops_fn_for_host( if is_dataset: hooks.append(inputs.dataset_initializer_hook()) - # TODO(ylc): Refactoring the code to merge the tpu ordinal logic here and the - # _InternalTPUContext.tpu_ordinal_function. We should either introduce another - # abstraction or a different helper method. - def _tpu_ordinal_function_impl(shard_index_in_host): - # We put both enqueue/dequeue op at tpu.core(0) in each replica. - replica = ctx.device_assignment.lookup_replicas( - host_id, (0, 0, 0))[shard_index_in_host] - return ctx.device_assignment.tpu_ordinal(replica=replica) - - if ctx.model_parallelism_enabled: - tpu_ordinal_function = _tpu_ordinal_function_impl - else: - tpu_ordinal_function = None + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): + """A Fn returning the TPU infeed enqueue ops. + + By providing as a Fn, it can be invoked inside the tf.while_loop such that + the input pipeline for multiple iterations can be executed by one + Session.run call. + + Returns: + list of dict of ops. + """ with ops.device(device): num_of_replicas_per_host = ctx.num_of_replicas_per_host # Convert user input to features and labels. If the user returns a @@ -745,7 +771,7 @@ def generate_per_host_enqueue_ops_fn_for_host( infeed_queue.split_inputs_and_generate_enqueue_ops( unsharded_tensor_list, placement_function=lambda x: device, - tpu_ordinal_function=tpu_ordinal_function)) + tpu_ordinal_function=tpu_ordinal_function_impl)) if signals is None: return per_host_enqueue_ops else: @@ -779,6 +805,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( raise TypeError('Most PREDICT not yet supported in PER_HOST_V2 mode.') hooks.append(inputs.dataset_initializer_hook()) + tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) def enqueue_ops_fn(): """Generates the per_host enqueue ops.""" @@ -809,7 +836,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function) + per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) return per_host_enqueue_ops return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset @@ -1095,10 +1122,16 @@ class _InputPipeline(object): return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator def _validate_input_pipeline(self): - # Perform some sanity checks to log user friendly information. We should - # error out to give users better error message. But, if - # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break - # user code, so, log a warning. + """Validates the input pipeline. + + Perform some sanity checks to log user friendly information. We should + error out to give users better error message. But, if + _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break + user code, so, log a warning. + + Raises: + RuntimeError: If the validation failed. + """ if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): err_msg = ('Input pipeline contains one or more QueueRunners. ' 'It could be slow and not scalable. Please consider ' @@ -1300,8 +1333,55 @@ class _ModelFnWrapper(object): key, tensor)) return predictions + def _validate_model_features_and_labels(self, + features, + labels, + is_export_mode): + """Validates that the features and labels for the model function are valid. + + A valid features/labels object is the one with: + - Type: Tensor or a dictionary of Tensors + - Static shape if is_export_mode is False. + + Args: + features: the features that would be input to the model function. + labels: the labels that would be input to the model function. + is_export_mode: boolean value specifying if in export mode. + + Raises: + TypeError: If features/labels are not of the correct type. + ValueError: If features/labels have dynamic shape. + """ + + def validate(obj, obj_name): + """Helper validate function.""" + if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict): + raise TypeError( + 'The {} to the model returned by input_fn must be either a Tensor ' + 'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name, + obj)) + if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode): + return + if isinstance(obj, ops.Tensor): + if not obj.get_shape().is_fully_defined(): + raise ValueError( + 'The {} to the model returned by input_fn must have static shape.' + ' Tensor: {}'.format(obj_name, obj)) + else: + for (key, tensor) in obj.items(): + if not tensor.get_shape().is_fully_defined(): + raise ValueError( + 'The {} to the model returned by input_fn must have static ' + 'shape. Key: \'{}\', Tensor: {}'.format( + obj_name, key, tensor)) + + validate(features, 'features') + if labels is not None: + validate(labels, 'labels') + def _call_model_fn(self, features, labels, is_export_mode=False): """Calls the model_fn with required parameters.""" + self._validate_model_features_and_labels(features, labels, is_export_mode) model_fn_args = function_utils.fn_args(self._model_fn) kwargs = {} @@ -1812,11 +1892,6 @@ class TPUEstimator(estimator_lib.Estimator): ... ``` - Current limitations: - -------------------- - - 1. Outside compilation does not work yet (b/79991729). - """ def __init__(self, @@ -1837,7 +1912,8 @@ class TPUEstimator(estimator_lib.Estimator): Args: model_fn: Model function as required by `Estimator`. For training, the returned `EstimatorSpec` cannot have hooks as it is not supported in - `TPUEstimator`. + `TPUEstimator`. Instead, the user can pass the training hooks as + an argument to `TPUEstimator.train()`. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If `None`, the model_dir in @@ -2034,10 +2110,21 @@ class TPUEstimator(estimator_lib.Estimator): # Reconstruct `tensors`, but with `tpu_tensors` replaced with # `tpu_tensors_on_cpu`. - new_tensors = [ - tpu_tensors_on_cpu.pop(0) if _is_tpu_tensor(t) else t - for t in tensors - ] + new_tensors = [] + for t in tensors: + if _is_tpu_tensor(t): + new_tensors.append(tpu_tensors_on_cpu.pop(0)) + elif t is None: + new_tensors.append(None) + else: + # Only fetching `tpu_tensors_on_cpu` does not trigger + # TPU computation and blocks, so we add the control dependency here. + control_inputs = (tpu_tensors_on_cpu + if isinstance(tpu_tensors_on_cpu, (list, tuple)) + else (tpu_tensors_on_cpu,)) + with ops.control_dependencies(control_inputs): + new_tensors.append(array_ops.identity(t)) + # Reconstruct `tensors_dict`. new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) # Reconstruct `export_outputs`. @@ -2898,6 +2985,7 @@ class _StopSignals(object): @staticmethod def should_stop(scalar_stopping_signal): + """Detects whether scalar_stopping_signal indicates stopping.""" if isinstance(scalar_stopping_signal, ops.Tensor): # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF # way to express the bool check whether scalar_stopping_signal is True. @@ -3017,7 +3105,7 @@ class _SignalsHelper(object): def __init__(self, signals): self._signal_keys = [] - for key in sorted(signals.iterkeys()): + for key in sorted(iter(signals.keys())): self._signal_keys.append(key) @property @@ -3029,7 +3117,7 @@ class _SignalsHelper(object): @staticmethod def as_tensor_list(signals): - return [signals[key] for key in sorted(signals.iterkeys())] + return [signals[key] for key in sorted(iter(signals.keys()))] def _verify_cross_hosts_transfer_size(tensor_dict, message): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py index e76cf83e4ddcd86ab3971bcecefe2e2dc979bf63..15f99d7eebddd46f9f6902b68f01e42359a72cbe 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py @@ -19,6 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.ops.losses import losses @@ -32,7 +34,8 @@ class CrossShardOptimizer(optimizer.Optimizer): def __init__(self, opt, reduction=losses.Reduction.MEAN, - name="CrossShardOptimizer"): + name="CrossShardOptimizer", + group_assignment=None): """Construct a new cross-shard optimizer. Args: @@ -40,6 +43,8 @@ class CrossShardOptimizer(optimizer.Optimizer): reduction: The reduction to apply to the shard losses. name: Optional name prefix for the operations created when applying gradients. Defaults to "CrossShardOptimizer". + group_assignment: Optional list of group ids for applying the optimizer + to subgroups. Raises: ValueError: If reduction is not a valid cross-shard reduction. @@ -50,6 +55,35 @@ class CrossShardOptimizer(optimizer.Optimizer): super(CrossShardOptimizer, self).__init__(False, name) self._opt = opt self._reduction = reduction + self._group_assignment = group_assignment + + def _verify_and_get_subgroup_size(self, group_assignment, num_shards): + """Verify group_assignment and get the subgroup size". + + Args: + group_assignment: list of group ids for applying the optimizer + to subgroups. + num_shards: The number of TPU shards. + + Returns: + The size of one subgroup in group_assignment. + + Raises: + ValueError: If group_assignment is invalid. + """ + if not group_assignment: + return None + if len(group_assignment) != num_shards: + raise ValueError("The size of group_assignment does not equal to " + "num_shard({0}). Got group_assignment={1}".format( + num_shards, self._group_assignment)) + subgroup_size_list = dict(collections.Counter(group_assignment)).values() + if all(subgroup_size_list[0] == size for size in subgroup_size_list): + return subgroup_size_list[0] + else: + raise ValueError("The size of each subgroup in group_assignment must " + "be equal. Got group_assignment={}".format( + self._group_assignment)) def compute_gradients(self, loss, var_list=None, **kwargs): """Compute gradients of "loss" for the variables in "var_list". @@ -71,7 +105,8 @@ class CrossShardOptimizer(optimizer.Optimizer): A list of (gradient, variable) pairs. Raises: - ValueError: If not within a tpu_shard_context. + ValueError: If not within a tpu_shard_context or group_assignment is + invalid. """ num_shards = tpu_function.get_tpu_context().number_of_shards if num_shards is None: @@ -79,9 +114,17 @@ class CrossShardOptimizer(optimizer.Optimizer): "CrossShardOptimizer should be used within a tpu_shard_context, but " "got unset number_of_shards. Assuming 1.") num_shards = 1 + + subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment, + num_shards) + if num_shards > 1 and self._reduction == losses.Reduction.MEAN: - scale = 1.0 / num_shards + if self._group_assignment: + scale = 1.0 / subgroup_size + else: + scale = 1.0 / num_shards loss *= scale + return self._opt.compute_gradients(loss, var_list=var_list, **kwargs) def apply_gradients(self, grads_and_vars, global_step=None, name=None): @@ -110,7 +153,8 @@ class CrossShardOptimizer(optimizer.Optimizer): if grad is None: summed_grads_and_vars.append((grad, var)) else: - summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var)) + summed_grads_and_vars.append((tpu_ops.cross_replica_sum( + grad, self._group_assignment), var)) return self._opt.apply_gradients(summed_grads_and_vars, global_step, name) def get_slot(self, *args, **kwargs): diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 5de55b5f7f2a41ac6edd27e5a102e565f33df12c..76927e62e82d02de172a0851819716dc63180371 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -295,7 +295,7 @@ py_test( tags = ["notsan"], deps = [ ":training_py", - "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", + "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py index 409aba817c1ec37003eb98f000f6cf8918234c5d..a2444934bc21d58ed57d15494b3548a31ce3a2df 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -45,14 +46,14 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset): self._input_dataset = input_dataset self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - # pylint: disable=protected-access if padded_shapes is None: self._padded_shapes = nest.map_structure( - dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes) + convert.partial_shape_to_tensor, input_dataset.output_shapes) else: self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor, + input_dataset.output_shapes, convert.partial_shape_to_tensor, padded_shapes) + # pylint: disable=protected-access padding_values = ( padding_values if padding_values is not None else dataset_ops._default_padding(input_dataset)) diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py index 0338f409a203c232e63e99534a8f6d6a43fa661e..df0a186f4f6963d7e874bb4ab74a8db7e10a52ee 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py @@ -19,7 +19,7 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index 9720fd6e8657de18cf8d7565f834568ae52fdbda..19cb8983b6836266ebfac70c54657a96324e8435 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -53,12 +53,12 @@ cc_library( ":grpc_verbs_service_impl", ":rdma_mgr", ":verbs_service_proto_cc", + "//tensorflow:grpc++", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++_unsecure", ], alwayslink = 1, ) @@ -69,7 +69,7 @@ cc_library( hdrs = ["grpc_verbs_service_impl.h"], deps = [ ":verbs_service_proto_cc", - "@grpc//:grpc++_unsecure", + "//tensorflow:grpc++", ], ) diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc index 742f946c9536973eb8a6a11afda1b32ae4a7726b..af29abd91feda22824e57c19c13a3f48fb1d61b7 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc @@ -15,9 +15,9 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS -#include "grpc++/alarm.h" -#include "grpc++/grpc++.h" -#include "grpc++/server_builder.h" +#include "grpcpp/alarm.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/server_builder.h" #include "tensorflow/contrib/verbs/grpc_verbs_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc index 991f9a9d8bdf883b1b68bfa1fb6af7bf51b7e66a..4da7b59c69c88a4d04be37543aae7f03decd2c52 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h" -#include "grpc++/impl/codegen/async_stream.h" -#include "grpc++/impl/codegen/async_unary_call.h" -#include "grpc++/impl/codegen/channel_interface.h" -#include "grpc++/impl/codegen/client_unary_call.h" -#include "grpc++/impl/codegen/method_handler_impl.h" -#include "grpc++/impl/codegen/rpc_service_method.h" -#include "grpc++/impl/codegen/service_type.h" -#include "grpc++/impl/codegen/sync_stream.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/channel_interface.h" +#include "grpcpp/impl/codegen/client_unary_call.h" +#include "grpcpp/impl/codegen/method_handler_impl.h" +#include "grpcpp/impl/codegen/rpc_service_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/sync_stream.h" namespace tensorflow { diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index 1f0f10517e98a32ae882c027330091928f1a6ee2..abe5e08b07cd71b7ca28321e6eb2cf0eec5d1b0f 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -16,14 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ #define TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ -#include "grpc++/impl/codegen/async_stream.h" -#include "grpc++/impl/codegen/async_unary_call.h" -#include "grpc++/impl/codegen/proto_utils.h" -#include "grpc++/impl/codegen/rpc_method.h" -#include "grpc++/impl/codegen/service_type.h" -#include "grpc++/impl/codegen/status.h" -#include "grpc++/impl/codegen/stub_options.h" -#include "grpc++/impl/codegen/sync_stream.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/proto_utils.h" +#include "grpcpp/impl/codegen/rpc_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/status.h" +#include "grpcpp/impl/codegen/stub_options.h" +#include "grpcpp/impl/codegen/sync_stream.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index da60b4b169eb6c50e51f36f47c51dc9df619072f..ef8c3f358a523b1a0064527b6a31b650ca9ee7d7 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -72,24 +72,24 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "cc_header_only_library", "full_path", "if_android", - "if_not_android_mips_and_mips64", "if_ios", "if_linux_x86_64", "if_mobile", "if_not_mobile", - "if_windows", "if_not_windows", - "tf_copts", + "if_windows", "tf_cc_test", "tf_cc_tests", + "tf_copts", "tf_cuda_library", "tf_gen_op_libs", "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", - "cc_header_only_library", + "tf_features_nomodules_if_android", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") @@ -113,11 +113,11 @@ load( "tf_additional_human_readable_json_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", + "tf_additional_lib_hdrs", + "tf_additional_lib_srcs", "tf_additional_libdevice_data", "tf_additional_libdevice_deps", "tf_additional_libdevice_srcs", - "tf_additional_lib_hdrs", - "tf_additional_lib_srcs", "tf_additional_minimal_lib_srcs", "tf_additional_mpi_lib_defines", "tf_additional_proto_hdrs", @@ -141,8 +141,8 @@ load( ) load( "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", "if_static", + "tf_cuda_tests_tags", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") @@ -879,6 +879,7 @@ cc_library( hdrs = [ "util/stats_calculator.h", ], + copts = tf_copts(), ) cc_library( @@ -890,6 +891,12 @@ cc_library( ], ) +cc_library( + name = "exec_on_stall", + hdrs = ["util/exec_on_stall.h"], + deps = [":framework_lite"], +) + cc_library( name = "ptr_util", hdrs = ["util/ptr_util.h"], @@ -992,6 +999,7 @@ tf_gen_op_libs( "nn_ops", "no_op", "parsing_ops", + "random_grad", "random_ops", "remote_fused_graph_ops", "resource_variable_ops", @@ -1442,6 +1450,7 @@ filegroup( "lib/png/**/*", "lib/gif/**/*", "util/events_writer.*", + "util/stats_calculator.*", "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/default/test_benchmark.*", @@ -1525,6 +1534,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", + ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", @@ -1565,6 +1575,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", + ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", @@ -2330,6 +2341,7 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/op_segment.h", "framework/rendezvous.h", # only needed for tests + "framework/resource_var.h", "framework/tensor_reference.h", "framework/tracking_allocator.h", # only needed for tests "framework/unique_tensor_references.h", @@ -2625,6 +2637,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/dma_helper.h", "common_runtime/eigen_thread_pool.h", "common_runtime/executor.h", + "common_runtime/executor_factory.h", "common_runtime/graph_optimizer.h", "common_runtime/local_device.h", "common_runtime/lower_if_op.h", @@ -2674,6 +2687,7 @@ tf_cuda_library( "common_runtime/device_resolver_local.cc", "common_runtime/device_set.cc", "common_runtime/executor.cc", + "common_runtime/executor_factory.cc", "common_runtime/function.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", @@ -3255,6 +3269,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "exec_on_stall_test", + size = "small", + srcs = ["util/exec_on_stall_test.cc"], + deps = [ + ":exec_on_stall", + ":framework_lite", + ":test", + ":test_main", + ], +) + tf_cc_test( name = "lib_jpeg_jpeg_mem_unittest", srcs = ["lib/jpeg/jpeg_mem_unittest.cc"], @@ -3346,6 +3372,7 @@ tf_cc_tests( "framework/bfloat16_test.cc", "framework/cancellation_test.cc", "framework/common_shape_fns_test.cc", + "framework/device_base_test.cc", "framework/function_test.cc", "framework/graph_def_util_test.cc", "framework/graph_to_functiondef_test.cc", diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 19d643880966f7607405539a5ad43d8e03dc13fb..06b797e32edc046bab498f8d775040d57ef62ce9 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -4,6 +4,7 @@ # The following targets can be used to access ApiDefs: # :base_api_def # :python_api_def +# :java_api_def package( default_visibility = ["//visibility:private"], @@ -29,6 +30,12 @@ filegroup( visibility = ["//tensorflow:internal"], ) +filegroup( + name = "java_api_def", + srcs = glob(["java_api/*"]), + visibility = ["//tensorflow:internal"], +) + cc_library( name = "excluded_ops_lib", srcs = ["excluded_ops.cc"], diff --git a/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..0c5b1eb45af6812bdd35e2fef43ac8c02a5b9388 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt @@ -0,0 +1,18 @@ +op { + graph_op_name: "BatchDatasetV2" + visibility: HIDDEN + in_arg { + name: "batch_size" + description: <